12 Commits

31 changed files with 2556 additions and 204 deletions

View File

@@ -32,9 +32,9 @@ jobs:
go-version: "1.25.0" go-version: "1.25.0"
- name: lint - name: lint
uses: golangci/golangci-lint-action@v6 uses: golangci/golangci-lint-action@v8
with: with:
version: latest version: v2.5.0
nix-build: nix-build:
strategy: strategy:
@@ -44,9 +44,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: cachix/install-nix-action@v30 - uses: DeterminateSystems/nix-installer-action@v17
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- uses: nix-community/cache-nix-action@v6 - uses: nix-community/cache-nix-action@v6
with: with:

View File

@@ -7,6 +7,7 @@ on:
permissions: permissions:
contents: write contents: write
packages: write
jobs: jobs:
release-linux: release-linux:
@@ -48,3 +49,78 @@ jobs:
args: release --clean --config .goreleaser-darwin.yaml --skip=validate args: release --clean --config .goreleaser-darwin.yaml --skip=validate
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
release-containers:
runs-on: ubuntu-latest
strategy:
matrix:
variant: [alpine, debian, ubuntu, scratch]
include:
- variant: alpine
is_default: true
- variant: debian
is_default: false
- variant: ubuntu
is_default: false
- variant: scratch
is_default: false
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: DeterminateSystems/nix-installer-action@v17
- uses: nix-community/cache-nix-action@v6
with:
primary-key: nix-${{ runner.os }}-${{ hashFiles('flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-
- name: login to ghcr
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: build container
run: nix build ".#snitch-${{ matrix.variant }}" --print-out-paths
- name: load and push container
env:
VERSION: ${{ github.ref_name }}
REPO: ghcr.io/${{ github.repository_owner }}/snitch
VARIANT: ${{ matrix.variant }}
IS_DEFAULT: ${{ matrix.is_default }}
run: |
VERSION="${VERSION#v}"
docker load < result
IMAGE_TAG=$(docker images --format '{{.Repository}}:{{.Tag}}' | grep "snitch:.*-${VARIANT}" | head -1)
if [ -z "$IMAGE_TAG" ]; then
echo "error: could not find loaded image for ${VARIANT}"
exit 1
fi
docker tag "$IMAGE_TAG" "${REPO}:${VERSION}-${VARIANT}"
docker tag "$IMAGE_TAG" "${REPO}:latest-${VARIANT}"
docker push "${REPO}:${VERSION}-${VARIANT}"
docker push "${REPO}:latest-${VARIANT}"
if [ "$IS_DEFAULT" = "true" ]; then
docker tag "$IMAGE_TAG" "${REPO}:${VERSION}"
docker tag "$IMAGE_TAG" "${REPO}:latest"
docker push "${REPO}:${VERSION}"
docker push "${REPO}:latest"
fi
- name: summary
env:
VERSION: ${{ github.ref_name }}
REPO: ghcr.io/${{ github.repository_owner }}/snitch
VARIANT: ${{ matrix.variant }}
run: |
VERSION="${VERSION#v}"
echo "pushed ${REPO}:${VERSION}-${VARIANT}" >> $GITHUB_STEP_SUMMARY

View File

@@ -82,6 +82,7 @@ aurs:
package: |- package: |-
install -Dm755 "./snitch" "${pkgdir}/usr/bin/snitch" install -Dm755 "./snitch" "${pkgdir}/usr/bin/snitch"
install -Dm644 "./LICENSE" "${pkgdir}/usr/share/licenses/snitch/LICENSE" install -Dm644 "./LICENSE" "${pkgdir}/usr/share/licenses/snitch/LICENSE"
install -Dm644 "./README.md" "${pkgdir}/usr/share/doc/snitch/README.md"
commit_msg_template: "Update to {{ .Tag }}" commit_msg_template: "Update to {{ .Tag }}"
skip_upload: auto skip_upload: auto

113
README.md
View File

@@ -6,13 +6,29 @@ a friendlier `ss` / `netstat` for humans. inspect network connections with a cle
## install ## install
### homebrew
```bash
brew install snitch
```
> thanks to [@bevanjkay](https://github.com/bevanjkay) for adding snitch to homebrew-core
### go ### go
```bash ```bash
go install github.com/karol-broda/snitch@latest go install github.com/karol-broda/snitch@latest
``` ```
### nixos / nix ### nixpkgs
```bash
nix-env -iA nixpkgs.snitch
```
> thanks to [@DieracDelta](https://github.com/DieracDelta) for adding snitch to nixpkgs
### nixos / nix (flake)
```bash ```bash
# try it # try it
@@ -28,6 +44,45 @@ nix profile install github:karol-broda/snitch
# then use: inputs.snitch.packages.${system}.default # then use: inputs.snitch.packages.${system}.default
``` ```
### home-manager (flake)
add snitch to your flake inputs and import the home-manager module:
```nix
{
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
home-manager.url = "github:nix-community/home-manager";
snitch.url = "github:karol-broda/snitch";
};
outputs = { nixpkgs, home-manager, snitch, ... }: {
homeConfigurations."user" = home-manager.lib.homeManagerConfiguration {
pkgs = nixpkgs.legacyPackages.x86_64-linux;
modules = [
snitch.homeManagerModules.default
{
programs.snitch = {
enable = true;
# optional: use the flake's package instead of nixpkgs
# package = snitch.packages.x86_64-linux.default;
settings = {
defaults = {
theme = "catppuccin-mocha";
interval = "2s";
resolve = true;
};
};
};
}
];
};
};
}
```
available themes: `ansi`, `catppuccin-mocha`, `catppuccin-macchiato`, `catppuccin-frappe`, `catppuccin-latte`, `gruvbox-dark`, `gruvbox-light`, `dracula`, `nord`, `tokyo-night`, `tokyo-night-storm`, `tokyo-night-light`, `solarized-dark`, `solarized-light`, `one-dark`, `mono`
### arch linux (aur) ### arch linux (aur)
```bash ```bash
@@ -52,6 +107,47 @@ curl -sSL https://raw.githubusercontent.com/karol-broda/snitch/master/install.sh
> **macos:** the install script automatically removes the quarantine attribute (`com.apple.quarantine`) from the binary to allow it to run without gatekeeper warnings. to disable this, set `KEEP_QUARANTINE=1`. > **macos:** the install script automatically removes the quarantine attribute (`com.apple.quarantine`) from the binary to allow it to run without gatekeeper warnings. to disable this, set `KEEP_QUARANTINE=1`.
### docker
pre-built oci images available from github container registry:
```bash
# pull from ghcr.io
docker pull ghcr.io/karol-broda/snitch:latest # alpine (default)
docker pull ghcr.io/karol-broda/snitch:latest-alpine # alpine (~17MB)
docker pull ghcr.io/karol-broda/snitch:latest-scratch # minimal, binary only (~9MB)
docker pull ghcr.io/karol-broda/snitch:latest-debian # debian trixie
docker pull ghcr.io/karol-broda/snitch:latest-ubuntu # ubuntu 24.04
# or use a specific version
docker pull ghcr.io/karol-broda/snitch:0.2.0-alpine
```
alternatively, build locally via nix flake:
```bash
nix build github:karol-broda/snitch#snitch-alpine
docker load < result
```
**running the container:**
```bash
# basic usage - sees host sockets but not process names
docker run --rm --net=host snitch:latest ls
# full info - includes PID, process name, user
docker run --rm --net=host --pid=host --cap-add=SYS_PTRACE snitch:latest ls
```
| flag | purpose |
|------|---------|
| `--net=host` | share host network namespace (required to see host connections) |
| `--pid=host` | share host pid namespace (needed for process info) |
| `--cap-add=SYS_PTRACE` | read process details from `/proc/<pid>` |
> **note:** `CAP_NET_ADMIN` and `CAP_NET_RAW` are not required. snitch reads from `/proc/net/*` which doesn't need special network capabilities.
### binary ### binary
download from [releases](https://github.com/karol-broda/snitch/releases): download from [releases](https://github.com/karol-broda/snitch/releases):
@@ -222,8 +318,23 @@ optional config file at `~/.config/snitch/snitch.toml`:
numeric = false # disable name resolution numeric = false # disable name resolution
dns_cache = true # cache dns lookups (set to false to disable) dns_cache = true # cache dns lookups (set to false to disable)
theme = "auto" # color theme: auto, dark, light, mono theme = "auto" # color theme: auto, dark, light, mono
[tui]
remember_state = false # remember view options between sessions
``` ```
### remembering view options
when `remember_state = true`, the tui will save and restore:
- filter toggles (tcp/udp, listen/established/other)
- sort field and direction
- address and port resolution settings
state is saved to `$XDG_STATE_HOME/snitch/tui.json` (defaults to `~/.local/state/snitch/tui.json`).
cli flags always take priority over saved state.
### environment variables ### environment variables
```bash ```bash

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/karol-broda/snitch/internal/errutil"
"github.com/karol-broda/snitch/internal/testutil" "github.com/karol-broda/snitch/internal/testutil"
) )
@@ -407,16 +408,16 @@ func TestEnvironmentVariables(t *testing.T) {
oldEnvVars := make(map[string]string) oldEnvVars := make(map[string]string)
for key, value := range tt.envVars { for key, value := range tt.envVars {
oldEnvVars[key] = os.Getenv(key) oldEnvVars[key] = os.Getenv(key)
os.Setenv(key, value) errutil.Setenv(key, value)
} }
// Clean up environment variables // Clean up environment variables
defer func() { defer func() {
for key, oldValue := range oldEnvVars { for key, oldValue := range oldEnvVars {
if oldValue == "" { if oldValue == "" {
os.Unsetenv(key) errutil.Unsetenv(key)
} else { } else {
os.Setenv(key, oldValue) errutil.Setenv(key, oldValue)
} }
} }
}() }()

View File

@@ -8,16 +8,18 @@ import (
"log" "log"
"os" "os"
"os/exec" "os/exec"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/color"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/resolver"
"strconv" "strconv"
"strings" "strings"
"text/tabwriter" "text/tabwriter"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/color"
"github.com/karol-broda/snitch/internal/config"
"github.com/karol-broda/snitch/internal/errutil"
"github.com/karol-broda/snitch/internal/resolver"
"github.com/tidwall/pretty" "github.com/tidwall/pretty"
"golang.org/x/term" "golang.org/x/term"
) )
@@ -25,6 +27,7 @@ import (
// ls-specific flags // ls-specific flags
var ( var (
outputFormat string outputFormat string
outputFile string
noHeaders bool noHeaders bool
showTimestamp bool showTimestamp bool
sortBy string sortBy string
@@ -70,9 +73,77 @@ func runListCommand(outputFormat string, args []string) {
selectedFields = strings.Split(fields, ",") selectedFields = strings.Split(fields, ",")
} }
// handle file output
if outputFile != "" {
writeToFile(rt.Connections, outputFile, selectedFields)
return
}
renderList(rt.Connections, outputFormat, selectedFields) renderList(rt.Connections, outputFormat, selectedFields)
} }
func writeToFile(connections []collector.Connection, filename string, selectedFields []string) {
file, err := os.Create(filename)
if err != nil {
log.Fatalf("failed to create file: %v", err)
}
defer errutil.Close(file)
// determine format from extension
format := "csv"
lowerFilename := strings.ToLower(filename)
if strings.HasSuffix(lowerFilename, ".json") {
format = "json"
} else if strings.HasSuffix(lowerFilename, ".tsv") {
format = "tsv"
}
if len(selectedFields) == 0 {
selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}
if showTimestamp {
selectedFields = append([]string{"ts"}, selectedFields...)
}
}
switch format {
case "json":
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(connections); err != nil {
log.Fatalf("failed to write JSON: %v", err)
}
case "tsv":
writeDelimited(file, connections, "\t", !noHeaders, selectedFields)
default:
writeDelimited(file, connections, ",", !noHeaders, selectedFields)
}
fmt.Fprintf(os.Stderr, "exported %d connections to %s\n", len(connections), filename)
}
func writeDelimited(w io.Writer, connections []collector.Connection, delimiter string, headers bool, selectedFields []string) {
if headers {
headerRow := make([]string, len(selectedFields))
for i, field := range selectedFields {
headerRow[i] = strings.ToUpper(field)
}
_, _ = fmt.Fprintln(w, strings.Join(headerRow, delimiter))
}
for _, conn := range connections {
fieldMap := getFieldMap(conn)
row := make([]string, len(selectedFields))
for i, field := range selectedFields {
val := fieldMap[field]
if delimiter == "," && (strings.Contains(val, ",") || strings.Contains(val, "\"") || strings.Contains(val, "\n")) {
val = "\"" + strings.ReplaceAll(val, "\"", "\"\"") + "\""
}
row[i] = val
}
_, _ = fmt.Fprintln(w, strings.Join(row, delimiter))
}
}
func renderList(connections []collector.Connection, format string, selectedFields []string) { func renderList(connections []collector.Connection, format string, selectedFields []string) {
switch format { switch format {
case "json": case "json":
@@ -120,6 +191,8 @@ func getFieldMap(c collector.Connection) map[string]string {
return map[string]string{ return map[string]string{
"pid": strconv.Itoa(c.PID), "pid": strconv.Itoa(c.PID),
"process": c.Process, "process": c.Process,
"cmdline": c.Cmdline,
"cwd": c.Cwd,
"user": c.User, "user": c.User,
"uid": strconv.Itoa(c.UID), "uid": strconv.Itoa(c.UID),
"proto": c.Proto, "proto": c.Proto,
@@ -185,7 +258,7 @@ func printCSV(conns []collector.Connection, headers bool, timestamp bool, select
func printPlainTable(conns []collector.Connection, headers bool, timestamp bool, selectedFields []string) { func printPlainTable(conns []collector.Connection, headers bool, timestamp bool, selectedFields []string) {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
defer w.Flush() defer errutil.Flush(w)
if len(selectedFields) == 0 { if len(selectedFields) == 0 {
selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"} selectedFields = []string{"pid", "process", "user", "proto", "state", "laddr", "lport", "raddr", "rport"}
@@ -199,7 +272,7 @@ func printPlainTable(conns []collector.Connection, headers bool, timestamp bool,
for _, field := range selectedFields { for _, field := range selectedFields {
headerRow = append(headerRow, strings.ToUpper(field)) headerRow = append(headerRow, strings.ToUpper(field))
} }
fmt.Fprintln(w, strings.Join(headerRow, "\t")) errutil.Ignore(fmt.Fprintln(w, strings.Join(headerRow, "\t")))
} }
for _, conn := range conns { for _, conn := range conns {
@@ -208,7 +281,7 @@ func printPlainTable(conns []collector.Connection, headers bool, timestamp bool,
for _, field := range selectedFields { for _, field := range selectedFields {
row = append(row, fieldMap[field]) row = append(row, fieldMap[field])
} }
fmt.Fprintln(w, strings.Join(row, "\t")) errutil.Ignore(fmt.Fprintln(w, strings.Join(row, "\t")))
} }
} }
@@ -393,6 +466,7 @@ func init() {
// ls-specific flags // ls-specific flags
lsCmd.Flags().StringVarP(&outputFormat, "output", "o", cfg.Defaults.OutputFormat, "Output format (table, wide, json, csv)") lsCmd.Flags().StringVarP(&outputFormat, "output", "o", cfg.Defaults.OutputFormat, "Output format (table, wide, json, csv)")
lsCmd.Flags().StringVarP(&outputFile, "output-file", "O", "", "Write output to file (format detected from extension: .csv, .tsv, .json)")
lsCmd.Flags().BoolVar(&noHeaders, "no-headers", cfg.Defaults.NoHeaders, "Omit headers for table/csv output") lsCmd.Flags().BoolVar(&noHeaders, "no-headers", cfg.Defaults.NoHeaders, "Omit headers for table/csv output")
lsCmd.Flags().BoolVar(&showTimestamp, "ts", false, "Include timestamp in output") lsCmd.Flags().BoolVar(&showTimestamp, "ts", false, "Include timestamp in output")
lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)") lsCmd.Flags().StringVarP(&sortBy, "sort", "s", cfg.Defaults.SortBy, "Sort by column (e.g., pid:desc)")

View File

@@ -8,7 +8,6 @@ import (
"log" "log"
"os" "os"
"os/signal" "os/signal"
"github.com/karol-broda/snitch/internal/collector"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -17,6 +16,9 @@ import (
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/errutil"
) )
type StatsData struct { type StatsData struct {
@@ -227,19 +229,19 @@ func printStatsCSV(stats *StatsData, headers bool) {
func printStatsTable(stats *StatsData, headers bool) { func printStatsTable(stats *StatsData, headers bool) {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
defer w.Flush() defer errutil.Flush(w)
if headers { if headers {
fmt.Fprintf(w, "TIMESTAMP\t%s\n", stats.Timestamp.Format(time.RFC3339)) errutil.Ignore(fmt.Fprintf(w, "TIMESTAMP\t%s\n", stats.Timestamp.Format(time.RFC3339)))
fmt.Fprintf(w, "TOTAL CONNECTIONS\t%d\n", stats.Total) errutil.Ignore(fmt.Fprintf(w, "TOTAL CONNECTIONS\t%d\n", stats.Total))
fmt.Fprintln(w) errutil.Ignore(fmt.Fprintln(w))
} }
// Protocol breakdown // Protocol breakdown
if len(stats.ByProto) > 0 { if len(stats.ByProto) > 0 {
if headers { if headers {
fmt.Fprintln(w, "BY PROTOCOL:") errutil.Ignore(fmt.Fprintln(w, "BY PROTOCOL:"))
fmt.Fprintln(w, "PROTO\tCOUNT") errutil.Ignore(fmt.Fprintln(w, "PROTO\tCOUNT"))
} }
protocols := make([]string, 0, len(stats.ByProto)) protocols := make([]string, 0, len(stats.ByProto))
for proto := range stats.ByProto { for proto := range stats.ByProto {
@@ -247,16 +249,16 @@ func printStatsTable(stats *StatsData, headers bool) {
} }
sort.Strings(protocols) sort.Strings(protocols)
for _, proto := range protocols { for _, proto := range protocols {
fmt.Fprintf(w, "%s\t%d\n", strings.ToUpper(proto), stats.ByProto[proto]) errutil.Ignore(fmt.Fprintf(w, "%s\t%d\n", strings.ToUpper(proto), stats.ByProto[proto]))
} }
fmt.Fprintln(w) errutil.Ignore(fmt.Fprintln(w))
} }
// State breakdown // State breakdown
if len(stats.ByState) > 0 { if len(stats.ByState) > 0 {
if headers { if headers {
fmt.Fprintln(w, "BY STATE:") errutil.Ignore(fmt.Fprintln(w, "BY STATE:"))
fmt.Fprintln(w, "STATE\tCOUNT") errutil.Ignore(fmt.Fprintln(w, "STATE\tCOUNT"))
} }
states := make([]string, 0, len(stats.ByState)) states := make([]string, 0, len(stats.ByState))
for state := range stats.ByState { for state := range stats.ByState {
@@ -264,16 +266,16 @@ func printStatsTable(stats *StatsData, headers bool) {
} }
sort.Strings(states) sort.Strings(states)
for _, state := range states { for _, state := range states {
fmt.Fprintf(w, "%s\t%d\n", state, stats.ByState[state]) errutil.Ignore(fmt.Fprintf(w, "%s\t%d\n", state, stats.ByState[state]))
} }
fmt.Fprintln(w) errutil.Ignore(fmt.Fprintln(w))
} }
// Process breakdown (top 10) // Process breakdown (top 10)
if len(stats.ByProc) > 0 { if len(stats.ByProc) > 0 {
if headers { if headers {
fmt.Fprintln(w, "BY PROCESS (TOP 10):") errutil.Ignore(fmt.Fprintln(w, "BY PROCESS (TOP 10):"))
fmt.Fprintln(w, "PID\tPROCESS\tCOUNT") errutil.Ignore(fmt.Fprintln(w, "PID\tPROCESS\tCOUNT"))
} }
limit := 10 limit := 10
if len(stats.ByProc) < limit { if len(stats.ByProc) < limit {
@@ -281,7 +283,7 @@ func printStatsTable(stats *StatsData, headers bool) {
} }
for i := 0; i < limit; i++ { for i := 0; i < limit; i++ {
proc := stats.ByProc[i] proc := stats.ByProc[i]
fmt.Fprintf(w, "%d\t%s\t%d\n", proc.PID, proc.Process, proc.Count) errutil.Ignore(fmt.Fprintf(w, "%d\t%s\t%d\n", proc.PID, proc.Process, proc.Count))
} }
} }
} }

View File

@@ -33,11 +33,12 @@ var topCmd = &cobra.Command{
resolver.SetNoCache(effectiveNoCache) resolver.SetNoCache(effectiveNoCache)
opts := tui.Options{ opts := tui.Options{
Theme: theme, Theme: theme,
Interval: topInterval, Interval: topInterval,
ResolveAddrs: resolveAddrs, ResolveAddrs: resolveAddrs,
ResolvePorts: resolvePorts, ResolvePorts: resolvePorts,
NoCache: effectiveNoCache, NoCache: effectiveNoCache,
RememberState: cfg.TUI.RememberState,
} }
// if any filter flag is set, use exclusive mode // if any filter flag is set, use exclusive mode

View File

@@ -18,6 +18,7 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/errutil"
"github.com/karol-broda/snitch/internal/tui" "github.com/karol-broda/snitch/internal/tui"
) )
@@ -93,13 +94,13 @@ func runUpgrade(cmd *cobra.Command, args []string) error {
if currentClean == latestClean { if currentClean == latestClean {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Println(tui.SymbolSuccess + " you are running the latest version") errutil.Println(green, tui.SymbolSuccess+" you are running the latest version")
return nil return nil
} }
if current == "dev" { if current == "dev" {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " you are running a development build") errutil.Println(yellow, tui.SymbolWarning+" you are running a development build")
fmt.Println() fmt.Println()
fmt.Println("use one of the methods below to install a release version:") fmt.Println("use one of the methods below to install a release version:")
fmt.Println() fmt.Println()
@@ -108,7 +109,7 @@ func runUpgrade(cmd *cobra.Command, args []string) error {
} }
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", current, latest) errutil.Printf(green, tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", current, latest)
fmt.Println() fmt.Println()
if !upgradeYes { if !upgradeYes {
@@ -116,8 +117,8 @@ func runUpgrade(cmd *cobra.Command, args []string) error {
fmt.Println() fmt.Println()
faint := color.New(color.Faint) faint := color.New(color.Faint)
cmdStyle := color.New(color.FgCyan) cmdStyle := color.New(color.FgCyan)
faint.Print(" in-place ") errutil.Print(faint, " in-place ")
cmdStyle.Println("snitch upgrade --yes") errutil.Println(cmdStyle, "snitch upgrade --yes")
return nil return nil
} }
@@ -134,17 +135,17 @@ func handleSpecificVersion(current, target string) error {
if isVersionLower(targetClean, firstUpgradeVersion) { if isVersionLower(targetClean, firstUpgradeVersion) {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Printf(tui.SymbolWarning+" warning: the upgrade command was introduced in v%s\n", firstUpgradeVersion) errutil.Printf(yellow, tui.SymbolWarning+" warning: the upgrade command was introduced in v%s\n", firstUpgradeVersion)
faint := color.New(color.Faint) faint := color.New(color.Faint)
faint.Printf(" version %s does not include this command\n", target) errutil.Printf(faint, " version %s does not include this command\n", target)
faint.Println(" you will need to use other methods to upgrade from that version") errutil.Println(faint, " you will need to use other methods to upgrade from that version")
fmt.Println() fmt.Println()
} }
currentClean := strings.TrimPrefix(current, "v") currentClean := strings.TrimPrefix(current, "v")
if currentClean == targetClean { if currentClean == targetClean {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Println(tui.SymbolSuccess + " you are already running this version") errutil.Println(green, tui.SymbolSuccess+" you are already running this version")
return nil return nil
} }
@@ -153,15 +154,15 @@ func handleSpecificVersion(current, target string) error {
cmdStyle := color.New(color.FgCyan) cmdStyle := color.New(color.FgCyan)
if isVersionLower(targetClean, currentClean) { if isVersionLower(targetClean, currentClean) {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Printf(tui.SymbolArrowDown+" this will downgrade from %s to %s\n", current, target) errutil.Printf(yellow, tui.SymbolArrowDown+" this will downgrade from %s to %s\n", current, target)
} else { } else {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Printf(tui.SymbolArrowUp+" this will upgrade from %s to %s\n", current, target) errutil.Printf(green, tui.SymbolArrowUp+" this will upgrade from %s to %s\n", current, target)
} }
fmt.Println() fmt.Println()
faint.Print("run ") errutil.Print(faint, "run ")
cmdStyle.Printf("snitch upgrade --version %s --yes", target) errutil.Printf(cmdStyle, "snitch upgrade --version %s --yes", target)
faint.Println(" to proceed") errutil.Println(faint, " to proceed")
return nil return nil
} }
@@ -175,20 +176,20 @@ func handleNixUpgrade(current, latest string) error {
currentCommit := extractCommitFromVersion(current) currentCommit := extractCommitFromVersion(current)
dirty := isNixDirty(current) dirty := isNixDirty(current)
faint.Print("current ") errutil.Print(faint, "current ")
version.Print(current) errutil.Print(version, current)
if currentCommit != "" { if currentCommit != "" {
faint.Printf(" (commit %s)", currentCommit) errutil.Printf(faint, " (commit %s)", currentCommit)
} }
fmt.Println() fmt.Println()
faint.Print("latest ") errutil.Print(faint, "latest ")
version.Println(latest) errutil.Println(version, latest)
fmt.Println() fmt.Println()
if dirty { if dirty {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " you are running a dirty nix build (uncommitted changes)") errutil.Println(yellow, tui.SymbolWarning+" you are running a dirty nix build (uncommitted changes)")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -196,8 +197,8 @@ func handleNixUpgrade(current, latest string) error {
if currentCommit == "" { if currentCommit == "" {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -205,11 +206,11 @@ func handleNixUpgrade(current, latest string) error {
releaseCommit, err := fetchCommitForTag(latest) releaseCommit, err := fetchCommitForTag(latest)
if err != nil { if err != nil {
faint.Printf(" (could not fetch release commit: %v)\n", err) errutil.Printf(faint, " (could not fetch release commit: %v)\n", err)
fmt.Println() fmt.Println()
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -222,20 +223,20 @@ func handleNixUpgrade(current, latest string) error {
if strings.HasPrefix(releaseCommit, currentCommit) || strings.HasPrefix(currentCommit, releaseShort) { if strings.HasPrefix(releaseCommit, currentCommit) || strings.HasPrefix(currentCommit, releaseShort) {
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Printf(tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort) errutil.Printf(green, tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort)
return nil return nil
} }
comparison, err := compareCommits(latest, currentCommit) comparison, err := compareCommits(latest, currentCommit)
if err != nil { if err != nil {
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", currentCommit, latest) errutil.Printf(green, tui.SymbolSuccess+" update available: %s "+tui.SymbolArrowRight+" %s\n", currentCommit, latest)
faint.Printf(" your commit: %s\n", currentCommit) errutil.Printf(faint, " your commit: %s\n", currentCommit)
faint.Printf(" release: %s (%s)\n", releaseShort, latest) errutil.Printf(faint, " release: %s (%s)\n", releaseShort, latest)
fmt.Println() fmt.Println()
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -243,30 +244,30 @@ func handleNixUpgrade(current, latest string) error {
if comparison.AheadBy > 0 { if comparison.AheadBy > 0 {
cyan := color.New(color.FgCyan) cyan := color.New(color.FgCyan)
cyan.Printf(tui.SymbolArrowUp+" you are %d commit(s) ahead of %s\n", comparison.AheadBy, latest) errutil.Printf(cyan, tui.SymbolArrowUp+" you are %d commit(s) ahead of %s\n", comparison.AheadBy, latest)
faint.Printf(" your commit: %s\n", currentCommit) errutil.Printf(faint, " your commit: %s\n", currentCommit)
faint.Printf(" release: %s (%s)\n", releaseShort, latest) errutil.Printf(faint, " release: %s (%s)\n", releaseShort, latest)
fmt.Println() fmt.Println()
faint.Println("you are running a newer build than the latest release") errutil.Println(faint, "you are running a newer build than the latest release")
return nil return nil
} }
if comparison.BehindBy > 0 { if comparison.BehindBy > 0 {
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess+" update available: %d commit(s) behind %s\n", comparison.BehindBy, latest) errutil.Printf(green, tui.SymbolSuccess+" update available: %d commit(s) behind %s\n", comparison.BehindBy, latest)
faint.Printf(" your commit: %s\n", currentCommit) errutil.Printf(faint, " your commit: %s\n", currentCommit)
faint.Printf(" release: %s (%s)\n", releaseShort, latest) errutil.Printf(faint, " release: %s (%s)\n", releaseShort, latest)
fmt.Println() fmt.Println()
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint.Println(" nix store is immutable; use nix commands to upgrade") errutil.Println(faint, " nix store is immutable; use nix commands to upgrade")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
} }
green := color.New(color.FgGreen) green := color.New(color.FgGreen)
green.Printf(tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort) errutil.Printf(green, tui.SymbolSuccess+" you are running %s (commit %s)\n", latest, releaseShort)
return nil return nil
} }
@@ -278,22 +279,22 @@ func handleNixSpecificVersion(current, target string) error {
printVersionComparisonTarget(current, target) printVersionComparisonTarget(current, target)
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " this is a nix installation") errutil.Println(yellow, tui.SymbolWarning+" this is a nix installation")
faint := color.New(color.Faint) faint := color.New(color.Faint)
faint.Println(" nix store is immutable; in-place upgrades are not supported") errutil.Println(faint, " nix store is immutable; in-place upgrades are not supported")
fmt.Println() fmt.Println()
bold := color.New(color.Bold) bold := color.New(color.Bold)
cmd := color.New(color.FgCyan) cmd := color.New(color.FgCyan)
bold.Println("to install a specific version with nix:") errutil.Println(bold, "to install a specific version with nix:")
fmt.Println() fmt.Println()
faint.Print(" specific ref ") errutil.Print(faint, " specific ref ")
cmd.Printf("nix profile install github:%s/%s/%s\n", repoOwner, repoName, target) errutil.Printf(cmd, "nix profile install github:%s/%s/%s\n", repoOwner, repoName, target)
faint.Print(" latest ") errutil.Print(faint, " latest ")
cmd.Printf("nix profile install github:%s/%s\n", repoOwner, repoName) errutil.Printf(cmd, "nix profile install github:%s/%s\n", repoOwner, repoName)
return nil return nil
} }
@@ -333,7 +334,7 @@ func fetchLatestVersion() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github api returned status %d", resp.StatusCode) return "", fmt.Errorf("github api returned status %d", resp.StatusCode)
@@ -355,10 +356,10 @@ func printVersionComparison(current, latest string) {
faint := color.New(color.Faint) faint := color.New(color.Faint)
version := color.New(color.FgCyan) version := color.New(color.FgCyan)
faint.Print("current ") errutil.Print(faint, "current ")
version.Println(current) errutil.Println(version, current)
faint.Print("latest ") errutil.Print(faint, "latest ")
version.Println(latest) errutil.Println(version, latest)
fmt.Println() fmt.Println()
} }
@@ -366,10 +367,10 @@ func printVersionComparisonTarget(current, target string) {
faint := color.New(color.Faint) faint := color.New(color.Faint)
version := color.New(color.FgCyan) version := color.New(color.FgCyan)
faint.Print("current ") errutil.Print(faint, "current ")
version.Println(current) errutil.Println(version, current)
faint.Print("target ") errutil.Print(faint, "target ")
version.Println(target) errutil.Println(version, target)
fmt.Println() fmt.Println()
} }
@@ -378,20 +379,20 @@ func printUpgradeInstructions() {
faint := color.New(color.Faint) faint := color.New(color.Faint)
cmd := color.New(color.FgCyan) cmd := color.New(color.FgCyan)
bold.Println("upgrade options:") errutil.Println(bold, "upgrade options:")
fmt.Println() fmt.Println()
faint.Print(" go install ") errutil.Print(faint, " go install ")
cmd.Printf("go install github.com/%s/%s@latest\n", repoOwner, repoName) errutil.Printf(cmd, "go install github.com/%s/%s@latest\n", repoOwner, repoName)
faint.Print(" shell script ") errutil.Print(faint, " shell script ")
cmd.Printf("curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | sh\n", repoOwner, repoName) errutil.Printf(cmd, "curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | sh\n", repoOwner, repoName)
faint.Print(" arch (aur) ") errutil.Print(faint, " arch (aur) ")
cmd.Println("yay -S snitch-bin") errutil.Println(cmd, "yay -S snitch-bin")
faint.Print(" nix ") errutil.Print(faint, " nix ")
cmd.Printf("nix profile upgrade --inputs-from github:%s/%s\n", repoOwner, repoName) errutil.Printf(cmd, "nix profile upgrade --inputs-from github:%s/%s\n", repoOwner, repoName)
} }
func performUpgrade(version string) error { func performUpgrade(version string) error {
@@ -407,7 +408,7 @@ func performUpgrade(version string) error {
if strings.HasPrefix(execPath, "/nix/store/") { if strings.HasPrefix(execPath, "/nix/store/") {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Println(tui.SymbolWarning + " cannot perform in-place upgrade for nix installation") errutil.Println(yellow, tui.SymbolWarning+" cannot perform in-place upgrade for nix installation")
fmt.Println() fmt.Println()
printNixUpgradeInstructions() printNixUpgradeInstructions()
return nil return nil
@@ -423,15 +424,15 @@ func performUpgrade(version string) error {
faint := color.New(color.Faint) faint := color.New(color.Faint)
cyan := color.New(color.FgCyan) cyan := color.New(color.FgCyan)
faint.Print(tui.SymbolDownload + " downloading ") errutil.Print(faint, tui.SymbolDownload+" downloading ")
cyan.Printf("%s", archiveName) errutil.Printf(cyan, "%s", archiveName)
faint.Println("...") errutil.Println(faint, "...")
resp, err := http.Get(downloadURL) resp, err := http.Get(downloadURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to download: %w", err) return fmt.Errorf("failed to download: %w", err)
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status %d", resp.StatusCode) return fmt.Errorf("download failed with status %d", resp.StatusCode)
@@ -441,7 +442,7 @@ func performUpgrade(version string) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to create temp directory: %w", err) return fmt.Errorf("failed to create temp directory: %w", err)
} }
defer os.RemoveAll(tmpDir) defer errutil.RemoveAll(tmpDir)
binaryPath, err := extractBinaryFromTarGz(resp.Body, tmpDir) binaryPath, err := extractBinaryFromTarGz(resp.Body, tmpDir)
if err != nil { if err != nil {
@@ -458,14 +459,14 @@ func performUpgrade(version string) error {
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
cmdStyle := color.New(color.FgCyan) cmdStyle := color.New(color.FgCyan)
yellow.Printf(tui.SymbolWarning+" elevated permissions required to install to %s\n", targetDir) errutil.Printf(yellow, tui.SymbolWarning+" elevated permissions required to install to %s\n", targetDir)
fmt.Println() fmt.Println()
faint.Println("run with sudo or install to a user-writable location:") errutil.Println(faint, "run with sudo or install to a user-writable location:")
fmt.Println() fmt.Println()
faint.Print(" sudo ") errutil.Print(faint, " sudo ")
cmdStyle.Println("sudo snitch upgrade --yes") errutil.Println(cmdStyle, "sudo snitch upgrade --yes")
faint.Print(" custom dir ") errutil.Print(faint, " custom dir ")
cmdStyle.Printf("curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | INSTALL_DIR=~/.local/bin sh\n", errutil.Printf(cmdStyle, "curl -sSL https://raw.githubusercontent.com/%s/%s/master/install.sh | INSTALL_DIR=~/.local/bin sh\n",
repoOwner, repoName) repoOwner, repoName)
return nil return nil
} }
@@ -491,11 +492,11 @@ func performUpgrade(version string) error {
if err := os.Remove(backupPath); err != nil { if err := os.Remove(backupPath); err != nil {
// non-fatal, just warn // non-fatal, just warn
yellow := color.New(color.FgYellow) yellow := color.New(color.FgYellow)
yellow.Fprintf(os.Stderr, tui.SymbolWarning + " warning: failed to remove backup file %s: %v\n", backupPath, err) errutil.Fprintf(yellow, os.Stderr, tui.SymbolWarning+" warning: failed to remove backup file %s: %v\n", backupPath, err)
} }
green := color.New(color.FgGreen, color.Bold) green := color.New(color.FgGreen, color.Bold)
green.Printf(tui.SymbolSuccess + " successfully upgraded to %s\n", version) errutil.Printf(green, tui.SymbolSuccess+" successfully upgraded to %s\n", version)
return nil return nil
} }
@@ -504,7 +505,7 @@ func extractBinaryFromTarGz(r io.Reader, destDir string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer gzr.Close() defer errutil.Close(gzr)
tr := tar.NewReader(gzr) tr := tar.NewReader(gzr)
@@ -534,10 +535,10 @@ func extractBinaryFromTarGz(r io.Reader, destDir string) (string, error) {
} }
if _, err := io.Copy(outFile, tr); err != nil { if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close() errutil.Close(outFile)
return "", err return "", err
} }
outFile.Close() errutil.Close(outFile)
return destPath, nil return destPath, nil
} }
@@ -551,8 +552,8 @@ func isWritable(path string) bool {
if err != nil { if err != nil {
return false return false
} }
f.Close() errutil.Close(f)
os.Remove(testFile) errutil.Remove(testFile)
return true return true
} }
@@ -561,13 +562,13 @@ func copyFile(src, dst string) error {
if err != nil { if err != nil {
return err return err
} }
defer srcFile.Close() defer errutil.Close(srcFile)
dstFile, err := os.Create(dst) dstFile, err := os.Create(dst)
if err != nil { if err != nil {
return err return err
} }
defer dstFile.Close() defer errutil.Close(dstFile)
if _, err := io.Copy(dstFile, srcFile); err != nil { if _, err := io.Copy(dstFile, srcFile); err != nil {
return err return err
@@ -580,7 +581,7 @@ func removeQuarantine(path string) {
cmd := exec.Command("xattr", "-d", "com.apple.quarantine", path) cmd := exec.Command("xattr", "-d", "com.apple.quarantine", path)
if err := cmd.Run(); err == nil { if err := cmd.Run(); err == nil {
faint := color.New(color.Faint) faint := color.New(color.Faint)
faint.Println(" removed macOS quarantine attribute") errutil.Println(faint, " removed macOS quarantine attribute")
} }
} }
@@ -633,7 +634,7 @@ func fetchCommitForTag(tag string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github api returned status %d", resp.StatusCode) return "", fmt.Errorf("github api returned status %d", resp.StatusCode)
@@ -654,7 +655,7 @@ func compareCommits(base, head string) (*githubCompare, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer errutil.Close(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("github api returned status %d", resp.StatusCode) return nil, fmt.Errorf("github api returned status %d", resp.StatusCode)
@@ -673,16 +674,16 @@ func printNixUpgradeInstructions() {
faint := color.New(color.Faint) faint := color.New(color.Faint)
cmd := color.New(color.FgCyan) cmd := color.New(color.FgCyan)
bold.Println("nix upgrade options:") errutil.Println(bold, "nix upgrade options:")
fmt.Println() fmt.Println()
faint.Print(" flake profile ") errutil.Print(faint, " flake profile ")
cmd.Printf("nix profile install github:%s/%s\n", repoOwner, repoName) errutil.Printf(cmd, "nix profile install github:%s/%s\n", repoOwner, repoName)
faint.Print(" flake update ") errutil.Print(faint, " flake update ")
cmd.Println("nix flake update snitch (in your system/home-manager config)") errutil.Println(cmd, "nix flake update snitch (in your system/home-manager config)")
faint.Print(" rebuild ") errutil.Print(faint, " rebuild ")
cmd.Println("nixos-rebuild switch or home-manager switch") errutil.Println(cmd, "nixos-rebuild switch or home-manager switch")
} }

View File

@@ -6,6 +6,8 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/karol-broda/snitch/internal/errutil"
) )
var ( var (
@@ -22,20 +24,20 @@ var versionCmd = &cobra.Command{
cyan := color.New(color.FgCyan) cyan := color.New(color.FgCyan)
faint := color.New(color.Faint) faint := color.New(color.Faint)
bold.Print("snitch ") errutil.Print(bold, "snitch ")
cyan.Println(Version) errutil.Println(cyan, Version)
fmt.Println() fmt.Println()
faint.Print(" commit ") errutil.Print(faint, " commit ")
fmt.Println(Commit) fmt.Println(Commit)
faint.Print(" built ") errutil.Print(faint, " built ")
fmt.Println(Date) fmt.Println(Date)
faint.Print(" go ") errutil.Print(faint, " go ")
fmt.Println(runtime.Version()) fmt.Println(runtime.Version())
faint.Print(" os ") errutil.Print(faint, " os ")
fmt.Printf("%s/%s\n", runtime.GOOS, runtime.GOARCH) fmt.Printf("%s/%s\n", runtime.GOOS, runtime.GOARCH)
}, },
} }

View File

@@ -80,11 +80,18 @@
in in
{ {
packages = eachSystem (system: packages = eachSystem (system:
let pkgs = pkgsFor system; in let
{ pkgs = pkgsFor system;
default = mkSnitch pkgs;
snitch = mkSnitch pkgs; snitch = mkSnitch pkgs;
} # containers only available on linux
containers = if pkgs.stdenv.isLinux
then import ./nix/containers.nix { inherit pkgs snitch; }
else { };
in
{
default = snitch;
inherit snitch;
} // containers
); );
devShells = eachSystem (system: devShells = eachSystem (system:
@@ -94,7 +101,7 @@
in in
{ {
default = pkgs.mkShell { default = pkgs.mkShell {
packages = [ go pkgs.git pkgs.vhs ]; packages = [ go pkgs.git pkgs.vhs pkgs.nix-prefetch-docker ];
env.GOTOOLCHAIN = "local"; env.GOTOOLCHAIN = "local";
shellHook = '' shellHook = ''
echo "go toolchain: $(go version)" echo "go toolchain: $(go version)"
@@ -106,5 +113,32 @@
overlays.default = final: _prev: { overlays.default = final: _prev: {
snitch = mkSnitch final; snitch = mkSnitch final;
}; };
homeManagerModules.default = import ./nix/hm-module.nix;
homeManagerModules.snitch = self.homeManagerModules.default;
# alias for flake-parts compatibility
homeModules.default = self.homeManagerModules.default;
homeModules.snitch = self.homeManagerModules.default;
checks = eachSystem (system:
let
pkgs = import nixpkgs {
inherit system;
overlays = [ self.overlays.default ];
};
in
{
# home manager module tests
hm-module = import ./nix/tests/hm-module-test.nix {
inherit pkgs;
lib = pkgs.lib;
hmModule = self.homeManagerModules.default;
};
# package builds correctly
package = self.packages.${system}.default;
}
);
}; };
} }

View File

@@ -37,6 +37,19 @@ static const char* get_username(int uid) {
return pw->pw_name; return pw->pw_name;
} }
// get current working directory for a process
static int get_proc_cwd(int pid, char *path, int pathlen) {
struct proc_vnodepathinfo vpi;
int ret = proc_pidinfo(pid, PROC_PIDVNODEPATHINFO, 0, &vpi, sizeof(vpi));
if (ret <= 0) {
path[0] = '\0';
return -1;
}
strncpy(path, vpi.pvi_cdir.vip_path, pathlen - 1);
path[pathlen - 1] = '\0';
return 0;
}
// socket info extraction - handles the union properly in C // socket info extraction - handles the union properly in C
typedef struct { typedef struct {
int family; int family;
@@ -164,6 +177,7 @@ func listAllPids() ([]int, error) {
func getConnectionsForPid(pid int) ([]Connection, error) { func getConnectionsForPid(pid int) ([]Connection, error) {
procName := getProcessName(pid) procName := getProcessName(pid)
cwd := getProcessCwd(pid)
uid := int(C.get_proc_uid(C.int(pid))) uid := int(C.get_proc_uid(C.int(pid)))
user := "" user := ""
if uid >= 0 { if uid >= 0 {
@@ -198,7 +212,7 @@ func getConnectionsForPid(pid int) ([]Connection, error) {
continue continue
} }
conn, ok := getSocketInfo(pid, int(fdInfo.proc_fd), procName, uid, user) conn, ok := getSocketInfo(pid, int(fdInfo.proc_fd), procName, cwd, uid, user)
if ok { if ok {
connections = append(connections, conn) connections = append(connections, conn)
} }
@@ -207,7 +221,7 @@ func getConnectionsForPid(pid int) ([]Connection, error) {
return connections, nil return connections, nil
} }
func getSocketInfo(pid, fd int, procName string, uid int, user string) (Connection, bool) { func getSocketInfo(pid, fd int, procName, cwd string, uid int, user string) (Connection, bool) {
var info C.socket_info_t var info C.socket_info_t
ret := C.get_socket_info(C.int(pid), C.int(fd), &info) ret := C.get_socket_info(C.int(pid), C.int(fd), &info)
@@ -276,6 +290,7 @@ func getSocketInfo(pid, fd int, procName string, uid int, user string) (Connecti
Rport: int(info.rport), Rport: int(info.rport),
PID: pid, PID: pid,
Process: procName, Process: procName,
Cwd: cwd,
UID: uid, UID: uid,
User: user, User: user,
Interface: guessNetworkInterface(laddr), Interface: guessNetworkInterface(laddr),
@@ -293,6 +308,15 @@ func getProcessName(pid int) string {
return C.GoString(&name[0]) return C.GoString(&name[0])
} }
func getProcessCwd(pid int) string {
var path [1024]C.char
ret := C.get_proc_cwd(C.int(pid), &path[0], 1024)
if ret != 0 {
return ""
}
return C.GoString(&path[0])
}
func ipv4ToString(addr uint32) string { func ipv4ToString(addr uint32) string {
ip := make(net.IP, 4) ip := make(net.IP, 4)
ip[0] = byte(addr) ip[0] = byte(addr)

View File

@@ -14,6 +14,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/karol-broda/snitch/internal/errutil"
) )
// set SNITCH_DEBUG_TIMING=1 to enable timing diagnostics // set SNITCH_DEBUG_TIMING=1 to enable timing diagnostics
@@ -123,6 +125,8 @@ func GetAllConnections() ([]Connection, error) {
type processInfo struct { type processInfo struct {
pid int pid int
command string command string
cmdline string
cwd string
uid int uid int
user string user string
} }
@@ -138,7 +142,7 @@ func buildInodeToProcessMap() (map[int64]*processInfo, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer procDir.Close() defer errutil.Close(procDir)
entries, err := procDir.Readdir(-1) entries, err := procDir.Readdir(-1)
if err != nil { if err != nil {
@@ -246,39 +250,50 @@ func scanProcessSockets(pid int) []inodeEntry {
func getProcessInfo(pid int) (*processInfo, error) { func getProcessInfo(pid int) (*processInfo, error) {
info := &processInfo{pid: pid} info := &processInfo{pid: pid}
pidStr := strconv.Itoa(pid)
commPath := filepath.Join("/proc", strconv.Itoa(pid), "comm") commPath := filepath.Join("/proc", pidStr, "comm")
commData, err := os.ReadFile(commPath) commData, err := os.ReadFile(commPath)
if err == nil && len(commData) > 0 { if err == nil && len(commData) > 0 {
info.command = strings.TrimSpace(string(commData)) info.command = strings.TrimSpace(string(commData))
} }
if info.command == "" { cmdlinePath := filepath.Join("/proc", pidStr, "cmdline")
cmdlinePath := filepath.Join("/proc", strconv.Itoa(pid), "cmdline") cmdlineData, err := os.ReadFile(cmdlinePath)
cmdlineData, err := os.ReadFile(cmdlinePath) if err == nil && len(cmdlineData) > 0 {
if err != nil { parts := bytes.Split(cmdlineData, []byte{0})
return nil, err var args []string
} for _, p := range parts {
if len(p) > 0 {
if len(cmdlineData) > 0 { args = append(args, string(p))
parts := bytes.Split(cmdlineData, []byte{0})
if len(parts) > 0 && len(parts[0]) > 0 {
fullPath := string(parts[0])
baseName := filepath.Base(fullPath)
if strings.Contains(baseName, " ") {
baseName = strings.Fields(baseName)[0]
}
info.command = baseName
} }
} }
info.cmdline = strings.Join(args, " ")
if info.command == "" && len(parts) > 0 && len(parts[0]) > 0 {
fullPath := string(parts[0])
baseName := filepath.Base(fullPath)
if strings.Contains(baseName, " ") {
baseName = strings.Fields(baseName)[0]
}
info.command = baseName
}
} else if info.command == "" {
return nil, err
} }
statusPath := filepath.Join("/proc", strconv.Itoa(pid), "status") cwdPath := filepath.Join("/proc", pidStr, "cwd")
cwdLink, err := os.Readlink(cwdPath)
if err == nil {
info.cwd = cwdLink
}
statusPath := filepath.Join("/proc", pidStr, "status")
statusFile, err := os.Open(statusPath) statusFile, err := os.Open(statusPath)
if err != nil { if err != nil {
return info, nil return info, nil
} }
defer statusFile.Close() defer errutil.Close(statusFile)
scanner := bufio.NewScanner(statusFile) scanner := bufio.NewScanner(statusFile)
for scanner.Scan() { for scanner.Scan() {
@@ -304,7 +319,7 @@ func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*process
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close() defer errutil.Close(file)
var connections []Connection var connections []Connection
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
@@ -359,6 +374,8 @@ func parseProcNet(path, proto string, ipVersion int, inodeMap map[int64]*process
if procInfo, exists := inodeMap[inode]; exists { if procInfo, exists := inodeMap[inode]; exists {
conn.PID = procInfo.pid conn.PID = procInfo.pid
conn.Process = procInfo.command conn.Process = procInfo.command
conn.Cmdline = procInfo.cmdline
conn.Cwd = procInfo.cwd
conn.UID = procInfo.uid conn.UID = procInfo.uid
conn.User = procInfo.user conn.User = procInfo.user
} }
@@ -473,7 +490,7 @@ func GetUnixSockets() ([]Connection, error) {
if err != nil { if err != nil {
return connections, nil return connections, nil
} }
defer file.Close() defer errutil.Close(file)
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
scanner.Scan() scanner.Scan()

View File

@@ -114,4 +114,60 @@ func BenchmarkBuildInodeMap(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _ = buildInodeToProcessMap() _, _ = buildInodeToProcessMap()
} }
}
func TestConnectionHasCmdlineAndCwd(t *testing.T) {
conns, err := GetConnections()
if err != nil {
t.Fatalf("GetConnections() returned an error: %v", err)
}
if len(conns) == 0 {
t.Skip("no connections to test")
}
// find a connection with a PID (owned by some process)
var connWithProcess *Connection
for i := range conns {
if conns[i].PID > 0 {
connWithProcess = &conns[i]
break
}
}
if connWithProcess == nil {
t.Skip("no connections with associated process found")
}
t.Logf("testing connection: pid=%d process=%s", connWithProcess.PID, connWithProcess.Process)
// cmdline and cwd should be populated for connections with PIDs
// note: they might be empty if we don't have permission to read them
if connWithProcess.Cmdline != "" {
t.Logf("cmdline: %s", connWithProcess.Cmdline)
} else {
t.Logf("cmdline is empty (might be permission issue)")
}
if connWithProcess.Cwd != "" {
t.Logf("cwd: %s", connWithProcess.Cwd)
} else {
t.Logf("cwd is empty (might be permission issue)")
}
}
func TestGetProcessInfoPopulatesCmdlineAndCwd(t *testing.T) {
// test that getProcessInfo correctly populates cmdline and cwd for our own process
info, err := getProcessInfo(1) // init process (usually has cwd of /)
if err != nil {
t.Logf("could not get process info for pid 1: %v", err)
t.Skip("skipping - may not have permission")
}
t.Logf("pid 1 info: command=%s cmdline=%s cwd=%s", info.command, info.cmdline, info.cwd)
// at minimum, we should have a command name
if info.command == "" && info.cmdline == "" {
t.Error("expected either command or cmdline to be populated")
}
} }

View File

@@ -128,3 +128,75 @@ func TestSortByTimestamp(t *testing.T) {
} }
} }
func TestSortByRemoteAddr(t *testing.T) {
conns := []Connection{
{Raddr: "192.168.1.100", Rport: 443},
{Raddr: "10.0.0.1", Rport: 80},
{Raddr: "172.16.0.50", Rport: 8080},
}
t.Run("sort by raddr ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRaddr, Direction: SortAsc})
if c[0].Raddr != "10.0.0.1" {
t.Errorf("expected '10.0.0.1' first, got '%s'", c[0].Raddr)
}
if c[1].Raddr != "172.16.0.50" {
t.Errorf("expected '172.16.0.50' second, got '%s'", c[1].Raddr)
}
if c[2].Raddr != "192.168.1.100" {
t.Errorf("expected '192.168.1.100' last, got '%s'", c[2].Raddr)
}
})
t.Run("sort by raddr descending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRaddr, Direction: SortDesc})
if c[0].Raddr != "192.168.1.100" {
t.Errorf("expected '192.168.1.100' first, got '%s'", c[0].Raddr)
}
})
}
func TestSortByRemotePort(t *testing.T) {
conns := []Connection{
{Raddr: "192.168.1.1", Rport: 443},
{Raddr: "192.168.1.2", Rport: 80},
{Raddr: "192.168.1.3", Rport: 8080},
}
t.Run("sort by rport ascending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRport, Direction: SortAsc})
if c[0].Rport != 80 {
t.Errorf("expected port 80 first, got %d", c[0].Rport)
}
if c[1].Rport != 443 {
t.Errorf("expected port 443 second, got %d", c[1].Rport)
}
if c[2].Rport != 8080 {
t.Errorf("expected port 8080 last, got %d", c[2].Rport)
}
})
t.Run("sort by rport descending", func(t *testing.T) {
c := make([]Connection, len(conns))
copy(c, conns)
SortConnections(c, SortOptions{Field: SortByRport, Direction: SortDesc})
if c[0].Rport != 8080 {
t.Errorf("expected port 8080 first, got %d", c[0].Rport)
}
})
}

View File

@@ -6,6 +6,8 @@ type Connection struct {
TS time.Time `json:"ts"` TS time.Time `json:"ts"`
PID int `json:"pid"` PID int `json:"pid"`
Process string `json:"process"` Process string `json:"process"`
Cmdline string `json:"cmdline,omitempty"`
Cwd string `json:"cwd,omitempty"`
User string `json:"user"` User string `json:"user"`
UID int `json:"uid"` UID int `json:"uid"`
Proto string `json:"proto"` Proto string `json:"proto"`

View File

@@ -5,6 +5,8 @@ import (
"testing" "testing"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/karol-broda/snitch/internal/errutil"
) )
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
@@ -29,8 +31,8 @@ func TestInit(t *testing.T) {
origTerm := os.Getenv("TERM") origTerm := os.Getenv("TERM")
// Set test env vars // Set test env vars
os.Setenv("NO_COLOR", tc.noColor) errutil.Setenv("NO_COLOR", tc.noColor)
os.Setenv("TERM", tc.term) errutil.Setenv("TERM", tc.term)
Init(tc.mode) Init(tc.mode)
@@ -39,8 +41,8 @@ func TestInit(t *testing.T) {
} }
// Restore original env vars // Restore original env vars
os.Setenv("NO_COLOR", origNoColor) errutil.Setenv("NO_COLOR", origNoColor)
os.Setenv("TERM", origTerm) errutil.Setenv("TERM", origTerm)
}) })
} }
} }

View File

@@ -14,6 +14,12 @@ import (
// Config represents the application configuration // Config represents the application configuration
type Config struct { type Config struct {
Defaults DefaultConfig `mapstructure:"defaults"` Defaults DefaultConfig `mapstructure:"defaults"`
TUI TUIConfig `mapstructure:"tui"`
}
// TUIConfig contains TUI-specific configuration
type TUIConfig struct {
RememberState bool `mapstructure:"remember_state"`
} }
// DefaultConfig contains default values for CLI options // DefaultConfig contains default values for CLI options
@@ -105,6 +111,9 @@ func setDefaults(v *viper.Viper) {
v.SetDefault("defaults.no_headers", false) v.SetDefault("defaults.no_headers", false)
v.SetDefault("defaults.output_format", "table") v.SetDefault("defaults.output_format", "table")
v.SetDefault("defaults.sort_by", "") v.SetDefault("defaults.sort_by", "")
// tui settings
v.SetDefault("tui.remember_state", false)
} }
func handleSpecialEnvVars(v *viper.Viper) { func handleSpecialEnvVars(v *viper.Viper) {
@@ -146,6 +155,9 @@ func Get() *Config {
OutputFormat: "table", OutputFormat: "table",
SortBy: "", SortBy: "",
}, },
TUI: TUIConfig{
RememberState: false,
},
} }
} }
return config return config
@@ -199,6 +211,11 @@ ipv6 = false
no_headers = false no_headers = false
output_format = "table" output_format = "table"
sort_by = "" sort_by = ""
[tui]
# remember view options (filters, sort, resolution) between sessions
# state is saved to $XDG_STATE_HOME/snitch/tui.json
remember_state = false
`, themeList, theme.DefaultTheme) `, themeList, theme.DefaultTheme)
// Ensure directory exists // Ensure directory exists

View File

@@ -0,0 +1,65 @@
package errutil
import (
"io"
"os"
"github.com/fatih/color"
)
func Ignore[T any](val T, _ error) T {
return val
}
func IgnoreErr(_ error) {}
func Close(c io.Closer) {
if c != nil {
_ = c.Close()
}
}
// color.Color wrappers - these discard the (int, error) return values
func Print(c *color.Color, a ...any) {
_, _ = c.Print(a...)
}
func Println(c *color.Color, a ...any) {
_, _ = c.Println(a...)
}
func Printf(c *color.Color, format string, a ...any) {
_, _ = c.Printf(format, a...)
}
func Fprintf(c *color.Color, w io.Writer, format string, a ...any) {
_, _ = c.Fprintf(w, format, a...)
}
// os function wrappers for test cleanup where errors are non-critical
func Setenv(key, value string) {
_ = os.Setenv(key, value)
}
func Unsetenv(key string) {
_ = os.Unsetenv(key)
}
func Remove(name string) {
_ = os.Remove(name)
}
func RemoveAll(path string) {
_ = os.RemoveAll(path)
}
// Flush calls Flush on a tabwriter and discards the error
type Flusher interface {
Flush() error
}
func Flush(f Flusher) {
_ = f.Flush()
}

133
internal/state/state.go Normal file
View File

@@ -0,0 +1,133 @@
package state
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"github.com/karol-broda/snitch/internal/collector"
)
// TUIState holds view options that can be persisted between sessions
type TUIState struct {
ShowTCP bool `json:"show_tcp"`
ShowUDP bool `json:"show_udp"`
ShowListening bool `json:"show_listening"`
ShowEstablished bool `json:"show_established"`
ShowOther bool `json:"show_other"`
SortField collector.SortField `json:"sort_field"`
SortReverse bool `json:"sort_reverse"`
ResolveAddrs bool `json:"resolve_addrs"`
ResolvePorts bool `json:"resolve_ports"`
}
var (
saveMu sync.Mutex
saveChan chan TUIState
once sync.Once
)
// Path returns the XDG-compliant state file path
func Path() string {
stateDir := os.Getenv("XDG_STATE_HOME")
if stateDir == "" {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
stateDir = filepath.Join(home, ".local", "state")
}
return filepath.Join(stateDir, "snitch", "tui.json")
}
// Load reads the TUI state from disk.
// returns nil if state file doesn't exist or can't be read.
func Load() *TUIState {
path := Path()
if path == "" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return nil
}
var state TUIState
if err := json.Unmarshal(data, &state); err != nil {
return nil
}
return &state
}
// Save writes the TUI state to disk synchronously.
// creates parent directories if needed.
func Save(state TUIState) error {
path := Path()
if path == "" {
return nil
}
saveMu.Lock()
defer saveMu.Unlock()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0644)
}
// SaveAsync queues a state save to happen in the background.
// only the most recent state is saved if multiple saves are queued.
func SaveAsync(state TUIState) {
once.Do(func() {
saveChan = make(chan TUIState, 1)
go saveWorker()
})
// non-blocking send, replace pending save with newer state
select {
case saveChan <- state:
default:
// channel full, drain and replace
select {
case <-saveChan:
default:
}
select {
case saveChan <- state:
default:
}
}
}
func saveWorker() {
for state := range saveChan {
_ = Save(state)
}
}
// Default returns a TUIState with default values
func Default() TUIState {
return TUIState{
ShowTCP: true,
ShowUDP: true,
ShowListening: true,
ShowEstablished: true,
ShowOther: true,
SortField: collector.SortByLport,
SortReverse: false,
ResolveAddrs: false,
ResolvePorts: false,
}
}

View File

@@ -0,0 +1,236 @@
package state
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/karol-broda/snitch/internal/collector"
)
func TestPath_XDGStateHome(t *testing.T) {
t.Setenv("XDG_STATE_HOME", "/custom/state")
path := Path()
expected := "/custom/state/snitch/tui.json"
if path != expected {
t.Errorf("Path() = %q, want %q", path, expected)
}
}
func TestPath_DefaultFallback(t *testing.T) {
t.Setenv("XDG_STATE_HOME", "")
path := Path()
home, err := os.UserHomeDir()
if err != nil {
t.Skip("cannot determine home directory")
}
expected := filepath.Join(home, ".local", "state", "snitch", "tui.json")
if path != expected {
t.Errorf("Path() = %q, want %q", path, expected)
}
}
func TestDefault(t *testing.T) {
d := Default()
if d.ShowTCP != true {
t.Error("expected ShowTCP to be true")
}
if d.ShowUDP != true {
t.Error("expected ShowUDP to be true")
}
if d.ShowListening != true {
t.Error("expected ShowListening to be true")
}
if d.ShowEstablished != true {
t.Error("expected ShowEstablished to be true")
}
if d.ShowOther != true {
t.Error("expected ShowOther to be true")
}
if d.SortField != collector.SortByLport {
t.Errorf("expected SortField to be %q, got %q", collector.SortByLport, d.SortField)
}
if d.SortReverse != false {
t.Error("expected SortReverse to be false")
}
if d.ResolveAddrs != false {
t.Error("expected ResolveAddrs to be false")
}
if d.ResolvePorts != false {
t.Error("expected ResolvePorts to be false")
}
}
func TestSaveAndLoad(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
state := TUIState{
ShowTCP: false,
ShowUDP: true,
ShowListening: true,
ShowEstablished: false,
ShowOther: true,
SortField: collector.SortByProcess,
SortReverse: true,
ResolveAddrs: true,
ResolvePorts: false,
}
err := Save(state)
if err != nil {
t.Fatalf("Save() error = %v", err)
}
// verify file was created
path := Path()
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Fatal("expected state file to exist after Save()")
}
loaded := Load()
if loaded == nil {
t.Fatal("Load() returned nil")
}
if loaded.ShowTCP != state.ShowTCP {
t.Errorf("ShowTCP = %v, want %v", loaded.ShowTCP, state.ShowTCP)
}
if loaded.ShowUDP != state.ShowUDP {
t.Errorf("ShowUDP = %v, want %v", loaded.ShowUDP, state.ShowUDP)
}
if loaded.ShowListening != state.ShowListening {
t.Errorf("ShowListening = %v, want %v", loaded.ShowListening, state.ShowListening)
}
if loaded.ShowEstablished != state.ShowEstablished {
t.Errorf("ShowEstablished = %v, want %v", loaded.ShowEstablished, state.ShowEstablished)
}
if loaded.ShowOther != state.ShowOther {
t.Errorf("ShowOther = %v, want %v", loaded.ShowOther, state.ShowOther)
}
if loaded.SortField != state.SortField {
t.Errorf("SortField = %v, want %v", loaded.SortField, state.SortField)
}
if loaded.SortReverse != state.SortReverse {
t.Errorf("SortReverse = %v, want %v", loaded.SortReverse, state.SortReverse)
}
if loaded.ResolveAddrs != state.ResolveAddrs {
t.Errorf("ResolveAddrs = %v, want %v", loaded.ResolveAddrs, state.ResolveAddrs)
}
if loaded.ResolvePorts != state.ResolvePorts {
t.Errorf("ResolvePorts = %v, want %v", loaded.ResolvePorts, state.ResolvePorts)
}
}
func TestLoad_NonExistent(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
loaded := Load()
if loaded != nil {
t.Error("expected Load() to return nil for non-existent file")
}
}
func TestLoad_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
// create directory and invalid json file
stateDir := filepath.Join(tmpDir, "snitch")
if err := os.MkdirAll(stateDir, 0755); err != nil {
t.Fatal(err)
}
stateFile := filepath.Join(stateDir, "tui.json")
if err := os.WriteFile(stateFile, []byte("not valid json"), 0644); err != nil {
t.Fatal(err)
}
loaded := Load()
if loaded != nil {
t.Error("expected Load() to return nil for invalid JSON")
}
}
func TestSave_CreatesDirectories(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
// snitch directory should not exist yet
snitchDir := filepath.Join(tmpDir, "snitch")
if _, err := os.Stat(snitchDir); err == nil {
t.Fatal("expected snitch directory to not exist initially")
}
err := Save(Default())
if err != nil {
t.Fatalf("Save() error = %v", err)
}
// directory should now exist
if _, err := os.Stat(snitchDir); os.IsNotExist(err) {
t.Error("expected Save() to create parent directories")
}
}
func TestSaveAsync(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
state := TUIState{
ShowTCP: false,
SortField: collector.SortByPID,
}
SaveAsync(state)
// wait for background save with timeout
deadline := time.Now().Add(100 * time.Millisecond)
for time.Now().Before(deadline) {
if loaded := Load(); loaded != nil {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Log("SaveAsync may not have completed in time (non-fatal in CI)")
}
func TestTUIState_JSONRoundtrip(t *testing.T) {
// verify all sort fields serialize correctly
sortFields := []collector.SortField{
collector.SortByLport,
collector.SortByProcess,
collector.SortByPID,
collector.SortByState,
collector.SortByProto,
}
tmpDir := t.TempDir()
t.Setenv("XDG_STATE_HOME", tmpDir)
for _, sf := range sortFields {
state := TUIState{
ShowTCP: true,
SortField: sf,
}
if err := Save(state); err != nil {
t.Fatalf("Save() error for %q: %v", sf, err)
}
loaded := Load()
if loaded == nil {
t.Fatalf("Load() returned nil for %q", sf)
}
if loaded.SortField != sf {
t.Errorf("SortField roundtrip failed: got %q, want %q", loaded.SortField, sf)
}
}
}

View File

@@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/karol-broda/snitch/internal/collector" "github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/errutil"
) )
// TestCollector wraps MockCollector for use in tests // TestCollector wraps MockCollector for use in tests
@@ -47,13 +48,13 @@ func SetupTestEnvironment(t *testing.T) (string, func()) {
oldConfig := os.Getenv("SNITCH_CONFIG") oldConfig := os.Getenv("SNITCH_CONFIG")
oldNoColor := os.Getenv("SNITCH_NO_COLOR") oldNoColor := os.Getenv("SNITCH_NO_COLOR")
os.Setenv("SNITCH_NO_COLOR", "1") // Disable colors in tests errutil.Setenv("SNITCH_NO_COLOR", "1")
// Cleanup function // Cleanup function
cleanup := func() { cleanup := func() {
os.RemoveAll(tempDir) errutil.RemoveAll(tempDir)
os.Setenv("SNITCH_CONFIG", oldConfig) errutil.Setenv("SNITCH_CONFIG", oldConfig)
os.Setenv("SNITCH_NO_COLOR", oldNoColor) errutil.Setenv("SNITCH_NO_COLOR", oldNoColor)
} }
return tempDir, cleanup return tempDir, cleanup
@@ -192,8 +193,8 @@ func (oc *OutputCapture) Stop() (string, string, error) {
os.Stderr = oc.oldStderr os.Stderr = oc.oldStderr
// Close files // Close files
oc.stdout.Close() errutil.Close(oc.stdout)
oc.stderr.Close() errutil.Close(oc.stderr)
// Read captured content // Read captured content
stdoutContent, err := os.ReadFile(oc.stdoutFile) stdoutContent, err := os.ReadFile(oc.stdoutFile)
@@ -207,9 +208,9 @@ func (oc *OutputCapture) Stop() (string, string, error) {
} }
// Cleanup // Cleanup
os.Remove(oc.stdoutFile) errutil.Remove(oc.stdoutFile)
os.Remove(oc.stderrFile) errutil.Remove(oc.stderrFile)
os.Remove(filepath.Dir(oc.stdoutFile)) errutil.Remove(filepath.Dir(oc.stdoutFile))
return string(stdoutContent), string(stderrContent), nil return string(stdoutContent), string(stderrContent), nil
} }

View File

@@ -38,6 +38,10 @@ func sortFieldLabel(f collector.SortField) string {
return "state" return "state"
case collector.SortByProto: case collector.SortByProto:
return "proto" return "proto"
case collector.SortByRaddr:
return "raddr"
case collector.SortByRport:
return "rport"
default: default:
return "port" return "port"
} }

View File

@@ -2,10 +2,12 @@ package tui
import ( import (
"fmt" "fmt"
"github.com/karol-broda/snitch/internal/collector" "strings"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/karol-broda/snitch/internal/collector"
) )
func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
@@ -14,6 +16,11 @@ func (m model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m.handleSearchKey(msg) return m.handleSearchKey(msg)
} }
// export modal captures all input
if m.showExportModal {
return m.handleExportKey(msg)
}
// kill confirmation dialog // kill confirmation dialog
if m.showKillConfirm { if m.showKillConfirm {
return m.handleKillConfirmKey(msg) return m.handleKillConfirmKey(msg)
@@ -52,6 +59,82 @@ func (m model) handleSearchKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m, nil return m, nil
} }
func (m model) handleExportKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "esc":
m.showExportModal = false
m.exportFilename = ""
m.exportFormat = ""
m.exportError = ""
case "tab":
// toggle format
if m.exportFormat == "tsv" {
m.exportFormat = "csv"
} else {
m.exportFormat = "tsv"
}
m.exportError = ""
case "enter":
// build final filename with extension
filename := m.exportFilename
if filename == "" {
filename = "connections"
}
ext := ".csv"
if m.exportFormat == "tsv" {
ext = ".tsv"
}
// only add extension if not already present
if !strings.HasSuffix(strings.ToLower(filename), ".csv") &&
!strings.HasSuffix(strings.ToLower(filename), ".tsv") {
filename = filename + ext
}
m.exportFilename = filename
err := m.exportConnections()
if err != nil {
m.exportError = err.Error()
return m, nil
}
visible := m.visibleConnections()
m.statusMessage = fmt.Sprintf("%s exported %d connections to %s", SymbolSuccess, len(visible), filename)
m.statusExpiry = time.Now().Add(3 * time.Second)
m.showExportModal = false
m.exportFilename = ""
m.exportFormat = ""
m.exportError = ""
return m, clearStatusAfter(3 * time.Second)
case "backspace":
if len(m.exportFilename) > 0 {
m.exportFilename = m.exportFilename[:len(m.exportFilename)-1]
}
m.exportError = ""
default:
// only accept valid filename characters
char := msg.String()
if len(char) == 1 && isValidFilenameChar(char[0]) {
m.exportFilename += char
m.exportError = ""
}
}
return m, nil
}
func isValidFilenameChar(c byte) bool {
// allow alphanumeric, dash, underscore, dot
return (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '_' || c == '.'
}
func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() { switch msg.String() {
case "esc", "enter", "q": case "esc", "enter", "q":
@@ -118,37 +201,52 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
case "t": case "t":
m.showTCP = !m.showTCP m.showTCP = !m.showTCP
m.clampCursor() m.clampCursor()
m.saveState()
case "u": case "u":
m.showUDP = !m.showUDP m.showUDP = !m.showUDP
m.clampCursor() m.clampCursor()
m.saveState()
case "l": case "l":
m.showListening = !m.showListening m.showListening = !m.showListening
m.clampCursor() m.clampCursor()
m.saveState()
case "e": case "e":
m.showEstablished = !m.showEstablished m.showEstablished = !m.showEstablished
m.clampCursor() m.clampCursor()
m.saveState()
case "o": case "o":
m.showOther = !m.showOther m.showOther = !m.showOther
m.clampCursor() m.clampCursor()
m.saveState()
case "a": case "a":
m.showTCP = true m.showTCP = true
m.showUDP = true m.showUDP = true
m.showListening = true m.showListening = true
m.showEstablished = true m.showEstablished = true
m.showOther = true m.showOther = true
m.saveState()
// sorting // sorting
case "s": case "s":
m.cycleSort() m.cycleSort()
m.saveState()
case "S": case "S":
m.sortReverse = !m.sortReverse m.sortReverse = !m.sortReverse
m.applySorting() m.applySorting()
m.saveState()
// search // search
case "/": case "/":
m.searchActive = true m.searchActive = true
m.searchQuery = "" m.searchQuery = ""
// export
case "x":
m.showExportModal = true
m.exportFilename = ""
m.exportFormat = "csv"
m.exportError = ""
// actions // actions
case "enter", " ": case "enter", " ":
visible := m.visibleConnections() visible := m.visibleConnections()
@@ -220,6 +318,7 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.statusMessage = "address resolution: off" m.statusMessage = "address resolution: off"
} }
m.statusExpiry = time.Now().Add(2 * time.Second) m.statusExpiry = time.Now().Add(2 * time.Second)
m.saveState()
return m, clearStatusAfter(2 * time.Second) return m, clearStatusAfter(2 * time.Second)
// toggle port resolution // toggle port resolution
@@ -231,6 +330,7 @@ func (m model) handleNormalKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.statusMessage = "port resolution: off" m.statusMessage = "port resolution: off"
} }
m.statusExpiry = time.Now().Add(2 * time.Second) m.statusExpiry = time.Now().Add(2 * time.Second)
m.saveState()
return m, clearStatusAfter(2 * time.Second) return m, clearStatusAfter(2 * time.Second)
} }
@@ -266,6 +366,8 @@ func (m *model) cycleSort() {
collector.SortByPID, collector.SortByPID,
collector.SortByState, collector.SortByState,
collector.SortByProto, collector.SortByProto,
collector.SortByRaddr,
collector.SortByRport,
} }
for i, f := range fields { for i, f := range fields {

View File

@@ -2,11 +2,16 @@ package tui
import ( import (
"fmt" "fmt"
"github.com/karol-broda/snitch/internal/collector" "os"
"github.com/karol-broda/snitch/internal/theme" "strconv"
"strings"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/karol-broda/snitch/internal/collector"
"github.com/karol-broda/snitch/internal/state"
"github.com/karol-broda/snitch/internal/theme"
) )
type model struct { type model struct {
@@ -51,20 +56,30 @@ type model struct {
// status message (temporary feedback) // status message (temporary feedback)
statusMessage string statusMessage string
statusExpiry time.Time statusExpiry time.Time
// export modal
showExportModal bool
exportFilename string
exportFormat string // "csv" or "tsv"
exportError string
// state persistence
rememberState bool
} }
type Options struct { type Options struct {
Theme string Theme string
Interval time.Duration Interval time.Duration
TCP bool TCP bool
UDP bool UDP bool
Listening bool Listening bool
Established bool Established bool
Other bool Other bool
FilterSet bool // true if user specified any filter flags FilterSet bool // true if user specified any filter flags
ResolveAddrs bool // when true, resolve IP addresses to hostnames ResolveAddrs bool // when true, resolve IP addresses to hostnames
ResolvePorts bool // when true, resolve port numbers to service names ResolvePorts bool // when true, resolve port numbers to service names
NoCache bool // when true, disable DNS caching NoCache bool // when true, disable DNS caching
RememberState bool // when true, persist view options between sessions
} }
func New(opts Options) model { func New(opts Options) model {
@@ -79,8 +94,27 @@ func New(opts Options) model {
showListening := true showListening := true
showEstablished := true showEstablished := true
showOther := true showOther := true
sortField := collector.SortByLport
sortReverse := false
resolveAddrs := opts.ResolveAddrs
resolvePorts := opts.ResolvePorts
// if user specified filters, use those instead // load saved state if enabled and no CLI filter flags were specified
if opts.RememberState && !opts.FilterSet {
if saved := state.Load(); saved != nil {
showTCP = saved.ShowTCP
showUDP = saved.ShowUDP
showListening = saved.ShowListening
showEstablished = saved.ShowEstablished
showOther = saved.ShowOther
sortField = saved.SortField
sortReverse = saved.SortReverse
resolveAddrs = saved.ResolveAddrs
resolvePorts = saved.ResolvePorts
}
}
// if user specified filters, use those instead (CLI flags take precedence)
if opts.FilterSet { if opts.FilterSet {
showTCP = opts.TCP showTCP = opts.TCP
showUDP = opts.UDP showUDP = opts.UDP
@@ -108,13 +142,15 @@ func New(opts Options) model {
showListening: showListening, showListening: showListening,
showEstablished: showEstablished, showEstablished: showEstablished,
showOther: showOther, showOther: showOther,
sortField: collector.SortByLport, sortField: sortField,
resolveAddrs: opts.ResolveAddrs, sortReverse: sortReverse,
resolvePorts: opts.ResolvePorts, resolveAddrs: resolveAddrs,
resolvePorts: resolvePorts,
theme: theme.GetTheme(opts.Theme), theme: theme.GetTheme(opts.Theme),
interval: interval, interval: interval,
lastRefresh: time.Now(), lastRefresh: time.Now(),
watchedPIDs: make(map[int]bool), watchedPIDs: make(map[int]bool),
rememberState: opts.RememberState,
} }
} }
@@ -187,6 +223,11 @@ func (m model) View() string {
return m.overlayModal(main, m.renderKillModal()) return m.overlayModal(main, m.renderKillModal())
} }
// overlay export modal on top of main view
if m.showExportModal {
return m.overlayModal(main, m.renderExportModal())
}
return main return main
} }
@@ -262,12 +303,19 @@ func (m model) matchesFilters(c collector.Connection) bool {
} }
func (m model) matchesSearch(c collector.Connection) bool { func (m model) matchesSearch(c collector.Connection) bool {
lportStr := strconv.Itoa(c.Lport)
rportStr := strconv.Itoa(c.Rport)
pidStr := strconv.Itoa(c.PID)
return containsIgnoreCase(c.Process, m.searchQuery) || return containsIgnoreCase(c.Process, m.searchQuery) ||
containsIgnoreCase(c.Laddr, m.searchQuery) || containsIgnoreCase(c.Laddr, m.searchQuery) ||
containsIgnoreCase(c.Raddr, m.searchQuery) || containsIgnoreCase(c.Raddr, m.searchQuery) ||
containsIgnoreCase(c.User, m.searchQuery) || containsIgnoreCase(c.User, m.searchQuery) ||
containsIgnoreCase(c.Proto, m.searchQuery) || containsIgnoreCase(c.Proto, m.searchQuery) ||
containsIgnoreCase(c.State, m.searchQuery) containsIgnoreCase(c.State, m.searchQuery) ||
containsIgnoreCase(lportStr, m.searchQuery) ||
containsIgnoreCase(rportStr, m.searchQuery) ||
containsIgnoreCase(pidStr, m.searchQuery)
} }
func (m model) isWatched(pid int) bool { func (m model) isWatched(pid int) bool {
@@ -291,3 +339,84 @@ func (m *model) toggleWatch(pid int) {
func (m model) watchedCount() int { func (m model) watchedCount() int {
return len(m.watchedPIDs) return len(m.watchedPIDs)
} }
// currentState returns the current view options as a TUIState for persistence
func (m model) currentState() state.TUIState {
return state.TUIState{
ShowTCP: m.showTCP,
ShowUDP: m.showUDP,
ShowListening: m.showListening,
ShowEstablished: m.showEstablished,
ShowOther: m.showOther,
SortField: m.sortField,
SortReverse: m.sortReverse,
ResolveAddrs: m.resolveAddrs,
ResolvePorts: m.resolvePorts,
}
}
// saveState persists current view options in the background
func (m model) saveState() {
if m.rememberState {
state.SaveAsync(m.currentState())
}
}
// exportConnections writes visible connections to a file in csv or tsv format
func (m model) exportConnections() error {
visible := m.visibleConnections()
if len(visible) == 0 {
return fmt.Errorf("no connections to export")
}
file, err := os.Create(m.exportFilename)
if err != nil {
return err
}
defer func() { _ = file.Close() }()
// determine delimiter from format selection or filename
delimiter := ","
if m.exportFormat == "tsv" || strings.HasSuffix(strings.ToLower(m.exportFilename), ".tsv") {
delimiter = "\t"
}
header := []string{"PID", "PROCESS", "USER", "PROTO", "STATE", "LADDR", "LPORT", "RADDR", "RPORT"}
_, err = file.WriteString(strings.Join(header, delimiter) + "\n")
if err != nil {
return err
}
for _, c := range visible {
// escape fields that might contain delimiter
process := escapeField(c.Process, delimiter)
user := escapeField(c.User, delimiter)
row := []string{
strconv.Itoa(c.PID),
process,
user,
c.Proto,
c.State,
c.Laddr,
strconv.Itoa(c.Lport),
c.Raddr,
strconv.Itoa(c.Rport),
}
_, err = file.WriteString(strings.Join(row, delimiter) + "\n")
if err != nil {
return err
}
}
return nil
}
// escapeField quotes a field if it contains the delimiter or quotes
func escapeField(s, delimiter string) string {
if strings.Contains(s, delimiter) || strings.Contains(s, "\"") || strings.Contains(s, "\n") {
return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\""
}
return s
}

View File

@@ -1,12 +1,15 @@
package tui package tui
import ( import (
"github.com/karol-broda/snitch/internal/collector" "os"
"path/filepath"
"strings"
"testing" "testing"
"time" "time"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/x/exp/teatest" "github.com/charmbracelet/x/exp/teatest"
"github.com/karol-broda/snitch/internal/collector"
) )
func TestTUI_InitialState(t *testing.T) { func TestTUI_InitialState(t *testing.T) {
@@ -430,3 +433,346 @@ func TestTUI_FormatRemoteHelper(t *testing.T) {
} }
} }
func TestTUI_MatchesSearchPort(t *testing.T) {
m := New(Options{Theme: "dark"})
tests := []struct {
name string
searchQuery string
conn collector.Connection
expected bool
}{
{
name: "matches local port",
searchQuery: "3000",
conn: collector.Connection{Lport: 3000},
expected: true,
},
{
name: "matches remote port",
searchQuery: "443",
conn: collector.Connection{Rport: 443},
expected: true,
},
{
name: "matches pid",
searchQuery: "1234",
conn: collector.Connection{PID: 1234},
expected: true,
},
{
name: "partial port match",
searchQuery: "80",
conn: collector.Connection{Lport: 8080},
expected: true,
},
{
name: "no match",
searchQuery: "9999",
conn: collector.Connection{Lport: 80, Rport: 443, PID: 1234},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m.searchQuery = tc.searchQuery
result := m.matchesSearch(tc.conn)
if result != tc.expected {
t.Errorf("matchesSearch() = %v, want %v", result, tc.expected)
}
})
}
}
func TestTUI_SortCycleIncludesRemote(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
// start at default (Lport)
if m.sortField != collector.SortByLport {
t.Fatalf("expected initial sort field to be lport, got %v", m.sortField)
}
// cycle through all fields and verify raddr and rport are included
foundRaddr := false
foundRport := false
seenFields := make(map[collector.SortField]bool)
for i := 0; i < 10; i++ {
m.cycleSort()
seenFields[m.sortField] = true
if m.sortField == collector.SortByRaddr {
foundRaddr = true
}
if m.sortField == collector.SortByRport {
foundRport = true
}
if foundRaddr && foundRport {
break
}
}
if !foundRaddr {
t.Error("expected sort cycle to include SortByRaddr")
}
if !foundRport {
t.Error("expected sort cycle to include SortByRport")
}
}
func TestTUI_ExportModal(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// initially export modal should not be shown
if m.showExportModal {
t.Fatal("expected showExportModal to be false initially")
}
// press 'x' to open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
if !m.showExportModal {
t.Error("expected showExportModal to be true after pressing 'x'")
}
// type filename
for _, c := range "test.csv" {
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{c}})
m = newModel.(model)
}
if m.exportFilename != "test.csv" {
t.Errorf("expected exportFilename to be 'test.csv', got '%s'", m.exportFilename)
}
// escape should close modal
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyEsc})
m = newModel.(model)
if m.showExportModal {
t.Error("expected showExportModal to be false after escape")
}
if m.exportFilename != "" {
t.Error("expected exportFilename to be cleared after escape")
}
}
func TestTUI_ExportModalDefaultFilename(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// add test data
m.connections = []collector.Connection{
{PID: 1234, Process: "nginx", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80},
}
// open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
// render export modal should show default filename hint
view := m.View()
if view == "" {
t.Error("expected non-empty view with export modal")
}
}
func TestTUI_ExportModalBackspace(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
// type filename
for _, c := range "test.csv" {
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{c}})
m = newModel.(model)
}
// backspace should remove last character
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyBackspace})
m = newModel.(model)
if m.exportFilename != "test.cs" {
t.Errorf("expected 'test.cs' after backspace, got '%s'", m.exportFilename)
}
}
func TestTUI_ExportConnectionsCSV(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.connections = []collector.Connection{
{PID: 1234, Process: "nginx", User: "www-data", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80, Raddr: "*", Rport: 0},
{PID: 5678, Process: "node", User: "node", Proto: "tcp", State: "ESTABLISHED", Laddr: "192.168.1.1", Lport: 3000, Raddr: "10.0.0.1", Rport: 443},
}
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "test_export.csv")
m.exportFilename = csvPath
err := m.exportConnections()
if err != nil {
t.Fatalf("exportConnections() failed: %v", err)
}
content, err := os.ReadFile(csvPath)
if err != nil {
t.Fatalf("failed to read exported file: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
if len(lines) != 3 {
t.Errorf("expected 3 lines (header + 2 data), got %d", len(lines))
}
if !strings.Contains(lines[0], "PID") || !strings.Contains(lines[0], "PROCESS") {
t.Error("header line should contain PID and PROCESS")
}
if !strings.Contains(lines[1], "nginx") || !strings.Contains(lines[1], "1234") {
t.Error("first data line should contain nginx and 1234")
}
if !strings.Contains(lines[2], "node") || !strings.Contains(lines[2], "5678") {
t.Error("second data line should contain node and 5678")
}
}
func TestTUI_ExportConnectionsTSV(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.connections = []collector.Connection{
{PID: 1234, Process: "nginx", User: "www-data", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80, Raddr: "*", Rport: 0},
}
tmpDir := t.TempDir()
tsvPath := filepath.Join(tmpDir, "test_export.tsv")
m.exportFilename = tsvPath
err := m.exportConnections()
if err != nil {
t.Fatalf("exportConnections() failed: %v", err)
}
content, err := os.ReadFile(tsvPath)
if err != nil {
t.Fatalf("failed to read exported file: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
// TSV should use tabs
if !strings.Contains(lines[0], "\t") {
t.Error("TSV file should use tabs as delimiters")
}
// CSV delimiter should not be present between fields
fields := strings.Split(lines[1], "\t")
if len(fields) < 9 {
t.Errorf("expected at least 9 tab-separated fields, got %d", len(fields))
}
}
func TestTUI_ExportWithFilters(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.showTCP = true
m.showUDP = false
m.connections = []collector.Connection{
{PID: 1, Process: "tcp_proc", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80},
{PID: 2, Process: "udp_proc", Proto: "udp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 53},
}
tmpDir := t.TempDir()
csvPath := filepath.Join(tmpDir, "filtered_export.csv")
m.exportFilename = csvPath
err := m.exportConnections()
if err != nil {
t.Fatalf("exportConnections() failed: %v", err)
}
content, err := os.ReadFile(csvPath)
if err != nil {
t.Fatalf("failed to read exported file: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
// should only have header + 1 TCP connection (UDP filtered out)
if len(lines) != 2 {
t.Errorf("expected 2 lines (header + 1 TCP), got %d", len(lines))
}
if strings.Contains(string(content), "udp_proc") {
t.Error("UDP connection should not be exported when UDP filter is off")
}
}
func TestTUI_ExportFormatToggle(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
// open export modal
newModel, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}})
m = newModel.(model)
// default format should be csv
if m.exportFormat != "csv" {
t.Errorf("expected default format 'csv', got '%s'", m.exportFormat)
}
// tab should toggle to tsv
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = newModel.(model)
if m.exportFormat != "tsv" {
t.Errorf("expected format 'tsv' after tab, got '%s'", m.exportFormat)
}
// tab again should toggle back to csv
newModel, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = newModel.(model)
if m.exportFormat != "csv" {
t.Errorf("expected format 'csv' after second tab, got '%s'", m.exportFormat)
}
}
func TestTUI_ExportModalRenderWithStats(t *testing.T) {
m := New(Options{Theme: "dark", Interval: time.Hour})
m.width = 120
m.height = 40
m.connections = []collector.Connection{
{PID: 1, Process: "nginx", Proto: "tcp", State: "LISTEN", Laddr: "0.0.0.0", Lport: 80},
{PID: 2, Process: "postgres", Proto: "tcp", State: "LISTEN", Laddr: "127.0.0.1", Lport: 5432},
{PID: 3, Process: "node", Proto: "tcp", State: "ESTABLISHED", Laddr: "192.168.1.1", Lport: 3000},
}
m.showExportModal = true
m.exportFormat = "csv"
view := m.View()
// modal should contain summary info
if !strings.Contains(view, "3") {
t.Error("modal should show connection count")
}
// modal should show format options
if !strings.Contains(view, "CSV") || !strings.Contains(view, "TSV") {
t.Error("modal should show format options")
}
}

View File

@@ -33,6 +33,8 @@ const (
BoxCross = string('\u253C') // light vertical and horizontal BoxCross = string('\u253C') // light vertical and horizontal
// misc // misc
SymbolDash = string('\u2013') // en dash SymbolDash = string('\u2013') // en dash
SymbolExport = string('\u21E5') // rightwards arrow to bar
SymbolPrompt = string('\u276F') // heavy right-pointing angle quotation mark ornament
) )

View File

@@ -203,7 +203,7 @@ func (m model) renderStatusLine() string {
return " " + m.theme.Styles.Warning.Render(m.statusMessage) return " " + m.theme.Styles.Warning.Render(m.statusMessage)
} }
left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state n/N dns w watch K kill s sort / search ? help q quit") left := " " + m.theme.Styles.Normal.Render("t/u proto l/e/o state n/N dns w watch K kill s sort / search x export ? help q quit")
// show watched count if any // show watched count if any
if m.watchedCount() > 0 { if m.watchedCount() > 0 {
@@ -271,6 +271,7 @@ func (m model) renderHelp() string {
other other
───── ─────
/ search / search
x export to csv/tsv (enter filename)
r refresh now r refresh now
q quit q quit
@@ -301,6 +302,8 @@ func (m model) renderDetail() string {
value string value string
}{ }{
{"process", c.Process}, {"process", c.Process},
{"cmdline", c.Cmdline},
{"cwd", c.Cwd},
{"pid", fmt.Sprintf("%d", c.PID)}, {"pid", fmt.Sprintf("%d", c.PID)},
{"user", c.User}, {"user", c.User},
{"protocol", c.Proto}, {"protocol", c.Proto},
@@ -368,6 +371,119 @@ func (m model) renderKillModal() string {
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
} }
func (m model) renderExportModal() string {
visible := m.visibleConnections()
// count protocols and states for preview
tcpCount, udpCount := 0, 0
listenCount, estabCount, otherCount := 0, 0, 0
for _, c := range visible {
if c.Proto == "tcp" || c.Proto == "tcp6" {
tcpCount++
} else {
udpCount++
}
switch c.State {
case "LISTEN":
listenCount++
case "ESTABLISHED":
estabCount++
default:
otherCount++
}
}
var lines []string
// header
lines = append(lines, "")
headerText := " " + SymbolExport + " EXPORT CONNECTIONS "
lines = append(lines, m.theme.Styles.Header.Render(headerText))
lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 36)))
lines = append(lines, "")
// stats preview section
lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" summary"))
lines = append(lines, fmt.Sprintf(" total: %s",
m.theme.Styles.Success.Render(fmt.Sprintf("%d connections", len(visible)))))
protoSummary := fmt.Sprintf(" proto: %s tcp %s udp",
m.theme.Styles.GetProtoStyle("tcp").Render(fmt.Sprintf("%d", tcpCount)),
m.theme.Styles.GetProtoStyle("udp").Render(fmt.Sprintf("%d", udpCount)))
lines = append(lines, protoSummary)
stateSummary := fmt.Sprintf(" state: %s listen %s estab %s other",
m.theme.Styles.GetStateStyle("LISTEN").Render(fmt.Sprintf("%d", listenCount)),
m.theme.Styles.GetStateStyle("ESTABLISHED").Render(fmt.Sprintf("%d", estabCount)),
m.theme.Styles.Normal.Render(fmt.Sprintf("%d", otherCount)))
lines = append(lines, stateSummary)
lines = append(lines, "")
// format selection
lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" format"))
csvStyle := m.theme.Styles.Normal
tsvStyle := m.theme.Styles.Normal
csvIndicator := " "
tsvIndicator := " "
if m.exportFormat == "tsv" {
tsvStyle = m.theme.Styles.Success
tsvIndicator = m.theme.Styles.Success.Render(SymbolSelected + " ")
} else {
csvStyle = m.theme.Styles.Success
csvIndicator = m.theme.Styles.Success.Render(SymbolSelected + " ")
}
formatLine := fmt.Sprintf(" %s%s %s%s",
csvIndicator, csvStyle.Render("CSV (comma)"),
tsvIndicator, tsvStyle.Render("TSV (tab)"))
lines = append(lines, formatLine)
lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 8)+" press "+m.theme.Styles.Warning.Render("tab")+" to toggle"))
lines = append(lines, "")
// filename input
lines = append(lines, m.theme.Styles.Normal.Render(" "+SymbolBullet+" filename"))
ext := ".csv"
if m.exportFormat == "tsv" {
ext = ".tsv"
}
filenameDisplay := m.exportFilename
if filenameDisplay == "" {
filenameDisplay = "connections"
}
inputBox := fmt.Sprintf(" %s %s%s",
m.theme.Styles.Success.Render(SymbolPrompt),
m.theme.Styles.Warning.Render(filenameDisplay),
m.theme.Styles.Success.Render(ext+"▌"))
lines = append(lines, inputBox)
lines = append(lines, "")
// error display
if m.exportError != "" {
lines = append(lines, m.theme.Styles.Error.Render(fmt.Sprintf(" %s %s", SymbolWarning, m.exportError)))
lines = append(lines, "")
}
// preview of fields
lines = append(lines, m.theme.Styles.Border.Render(" "+strings.Repeat(BoxHorizontal, 36)))
fieldsPreview := " fields: PID, PROCESS, USER, PROTO, STATE, LADDR, LPORT, RADDR, RPORT"
lines = append(lines, m.theme.Styles.Normal.Render(truncate(fieldsPreview, 40)))
lines = append(lines, "")
// action buttons
lines = append(lines, fmt.Sprintf(" %s export %s toggle format %s cancel",
m.theme.Styles.Success.Render("[enter]"),
m.theme.Styles.Warning.Render("[tab]"),
m.theme.Styles.Error.Render("[esc]")))
lines = append(lines, "")
return strings.Join(lines, "\n")
}
func (m model) overlayModal(background, modal string) string { func (m model) overlayModal(background, modal string) string {
bgLines := strings.Split(background, "\n") bgLines := strings.Split(background, "\n")
modalLines := strings.Split(modal, "\n") modalLines := strings.Split(modal, "\n")

121
nix/containers.nix Normal file
View File

@@ -0,0 +1,121 @@
# oci container definitions for snitch
# builds containers based on different base images: alpine, debian trixie, ubuntu
#
# base images are pinned by imageDigest (immutable content hash), not by tag.
# even if the upstream tag gets a new image, builds remain reproducible.
#
# to update base image hashes, run:
# nix-prefetch-docker --image-name alpine --image-tag 3.21
# nix-prefetch-docker --image-name debian --image-tag trixie-slim
# nix-prefetch-docker --image-name ubuntu --image-tag 24.04
#
# this outputs both imageDigest and sha256 values needed below
{ pkgs, snitch }:
let
commonConfig = {
name = "snitch";
tag = snitch.version;
config = {
Entrypoint = [ "${snitch}/bin/snitch" ];
Env = [ "PATH=/bin" ];
Labels = {
"org.opencontainers.image.title" = "snitch";
"org.opencontainers.image.description" = "a friendlier ss/netstat for humans";
"org.opencontainers.image.source" = "https://github.com/karol-broda/snitch";
"org.opencontainers.image.licenses" = "MIT";
};
};
};
# alpine-based container
alpine = pkgs.dockerTools.pullImage {
imageName = "alpine";
imageDigest = "sha256:a8560b36e8b8210634f77d9f7f9efd7ffa463e380b75e2e74aff4511df3ef88c";
sha256 = "sha256-WNbRh44zld3lZtKARhdeWFte9JKgD2bgCuKzETWgGr8=";
finalImageName = "alpine";
finalImageTag = "3.21";
};
# debian trixie (testing) based container
debianTrixie = pkgs.dockerTools.pullImage {
imageName = "debian";
imageDigest = "sha256:e711a7b30ec1261130d0a121050b4ed81d7fb28aeabcf4ea0c7876d4e9f5aca2";
sha256 = "sha256-W/9A7aaPXFCmmg+NTSrFYL+QylsAgfnvkLldyI18tqU=";
finalImageName = "debian";
finalImageTag = "trixie-slim";
};
# ubuntu based container
ubuntu = pkgs.dockerTools.pullImage {
imageName = "ubuntu";
imageDigest = "sha256:c35e29c9450151419d9448b0fd75374fec4fff364a27f176fb458d472dfc9e54";
sha256 = "sha256-0j8xM+mECrBBHv7ZqofiRaeSoOXFBtLYjgnKivQztS0=";
finalImageName = "ubuntu";
finalImageTag = "24.04";
};
# scratch container (minimal, just the snitch binary)
scratch = pkgs.dockerTools.buildImage {
name = "snitch";
tag = "${snitch.version}-scratch";
copyToRoot = pkgs.buildEnv {
name = "snitch-root";
paths = [ snitch ];
pathsToLink = [ "/bin" ];
};
config = commonConfig.config;
};
in
{
snitch-alpine = pkgs.dockerTools.buildImage {
name = "snitch";
tag = "${snitch.version}-alpine";
fromImage = alpine;
copyToRoot = pkgs.buildEnv {
name = "snitch-root";
paths = [ snitch ];
pathsToLink = [ "/bin" ];
};
config = commonConfig.config;
};
snitch-debian = pkgs.dockerTools.buildImage {
name = "snitch";
tag = "${snitch.version}-debian";
fromImage = debianTrixie;
copyToRoot = pkgs.buildEnv {
name = "snitch-root";
paths = [ snitch ];
pathsToLink = [ "/bin" ];
};
config = commonConfig.config;
};
snitch-ubuntu = pkgs.dockerTools.buildImage {
name = "snitch";
tag = "${snitch.version}-ubuntu";
fromImage = ubuntu;
copyToRoot = pkgs.buildEnv {
name = "snitch-root";
paths = [ snitch ];
pathsToLink = [ "/bin" ];
};
config = commonConfig.config;
};
snitch-scratch = scratch;
oci-default = pkgs.dockerTools.buildImage {
name = "snitch";
tag = snitch.version;
fromImage = alpine;
copyToRoot = pkgs.buildEnv {
name = "snitch-root";
paths = [ snitch ];
pathsToLink = [ "/bin" ];
};
config = commonConfig.config;
};
}

177
nix/hm-module.nix Normal file
View File

@@ -0,0 +1,177 @@
{
config,
lib,
pkgs,
...
}:
let
cfg = config.programs.snitch;
themes = [
"ansi"
"catppuccin-mocha"
"catppuccin-macchiato"
"catppuccin-frappe"
"catppuccin-latte"
"gruvbox-dark"
"gruvbox-light"
"dracula"
"nord"
"tokyo-night"
"tokyo-night-storm"
"tokyo-night-light"
"solarized-dark"
"solarized-light"
"one-dark"
"mono"
"auto"
];
defaultFields = [
"pid"
"process"
"user"
"proto"
"state"
"laddr"
"lport"
"raddr"
"rport"
];
tomlFormat = pkgs.formats.toml { };
settingsType = lib.types.submodule {
freeformType = tomlFormat.type;
options = {
defaults = lib.mkOption {
type = lib.types.submodule {
freeformType = tomlFormat.type;
options = {
interval = lib.mkOption {
type = lib.types.str;
default = "1s";
example = "2s";
description = "Default refresh interval for watch/stats/trace commands.";
};
numeric = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Disable name/service resolution by default.";
};
fields = lib.mkOption {
type = lib.types.listOf lib.types.str;
default = defaultFields;
example = [ "pid" "process" "proto" "state" "laddr" "lport" ];
description = "Default fields to display.";
};
theme = lib.mkOption {
type = lib.types.enum themes;
default = "ansi";
description = ''
Color theme for the TUI. "ansi" inherits terminal colors.
'';
};
units = lib.mkOption {
type = lib.types.enum [ "auto" "si" "iec" ];
default = "auto";
description = "Default units for byte display.";
};
color = lib.mkOption {
type = lib.types.enum [ "auto" "always" "never" ];
default = "auto";
description = "Default color mode.";
};
resolve = lib.mkOption {
type = lib.types.bool;
default = true;
description = "Enable name resolution by default.";
};
dns_cache = lib.mkOption {
type = lib.types.bool;
default = true;
description = "Enable DNS caching.";
};
ipv4 = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Filter to IPv4 only by default.";
};
ipv6 = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Filter to IPv6 only by default.";
};
no_headers = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Omit headers in output by default.";
};
output_format = lib.mkOption {
type = lib.types.enum [ "table" "json" "csv" ];
default = "table";
description = "Default output format.";
};
sort_by = lib.mkOption {
type = lib.types.str;
default = "";
example = "pid";
description = "Default sort field.";
};
};
};
default = { };
description = "Default settings for snitch commands.";
};
};
};
in
{
options.programs.snitch = {
enable = lib.mkEnableOption "snitch, a friendlier ss/netstat for humans";
package = lib.mkPackageOption pkgs "snitch" { };
settings = lib.mkOption {
type = settingsType;
default = { };
example = lib.literalExpression ''
{
defaults = {
theme = "catppuccin-mocha";
interval = "2s";
resolve = true;
};
}
'';
description = ''
Configuration written to {file}`$XDG_CONFIG_HOME/snitch/snitch.toml`.
See <https://github.com/karol-broda/snitch> for available options.
'';
};
};
config = lib.mkIf cfg.enable {
home.packages = [ cfg.package ];
xdg.configFile."snitch/snitch.toml" = lib.mkIf (cfg.settings != { }) {
source = tomlFormat.generate "snitch.toml" cfg.settings;
};
};
}

View File

@@ -0,0 +1,429 @@
# home manager module tests
#
# run with: nix build .#checks.x86_64-linux.hm-module
#
# tests cover:
# - module evaluation with various configurations
# - type validation for all options
# - generated TOML content verification
# - edge cases (disabled, empty settings, full settings)
{ pkgs, lib, hmModule }:
let
# minimal home-manager stub for standalone module testing
hmLib = {
hm.types.dagOf = lib.types.attrsOf;
dag.entryAnywhere = x: x;
};
# evaluate the hm module with a given config
evalModule = testConfig:
lib.evalModules {
modules = [
hmModule
# stub home-manager's expected structure
{
options = {
home.packages = lib.mkOption {
type = lib.types.listOf lib.types.package;
default = [ ];
};
xdg.configFile = lib.mkOption {
type = lib.types.attrsOf (lib.types.submodule {
options = {
source = lib.mkOption { type = lib.types.path; };
text = lib.mkOption { type = lib.types.str; default = ""; };
};
});
default = { };
};
};
}
testConfig
];
specialArgs = { inherit pkgs lib; };
};
# read generated TOML file content
readGeneratedToml = evalResult:
let
configFile = evalResult.config.xdg.configFile."snitch/snitch.toml" or null;
in
if configFile != null && configFile ? source
then builtins.readFile configFile.source
else null;
# test cases
tests = {
# test 1: module evaluates when disabled
moduleDisabled = {
name = "module-disabled";
config = {
programs.snitch.enable = false;
};
assertions = evalResult: [
{
assertion = evalResult.config.home.packages == [ ];
message = "packages should be empty when disabled";
}
{
assertion = !(evalResult.config.xdg.configFile ? "snitch/snitch.toml");
message = "config file should not exist when disabled";
}
];
};
# test 2: module evaluates with enable only (defaults)
moduleEnabledDefaults = {
name = "module-enabled-defaults";
config = {
programs.snitch.enable = true;
};
assertions = evalResult: [
{
assertion = builtins.length evalResult.config.home.packages == 1;
message = "package should be installed when enabled";
}
];
};
# test 3: all theme values are valid
themeValidation = {
name = "theme-validation";
config = {
programs.snitch = {
enable = true;
settings.defaults.theme = "catppuccin-mocha";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = toml != null;
message = "TOML config should be generated";
}
{
assertion = lib.hasInfix "catppuccin-mocha" toml;
message = "theme should be set in TOML";
}
];
};
# test 4: full configuration with all options
fullConfiguration = {
name = "full-configuration";
config = {
programs.snitch = {
enable = true;
settings.defaults = {
interval = "2s";
numeric = true;
fields = [ "pid" "process" "proto" ];
theme = "nord";
units = "si";
color = "always";
resolve = false;
dns_cache = false;
ipv4 = true;
ipv6 = false;
no_headers = true;
output_format = "json";
sort_by = "pid";
};
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = toml != null;
message = "TOML config should be generated";
}
{
assertion = lib.hasInfix "interval = \"2s\"" toml;
message = "interval should be 2s";
}
{
assertion = lib.hasInfix "numeric = true" toml;
message = "numeric should be true";
}
{
assertion = lib.hasInfix "theme = \"nord\"" toml;
message = "theme should be nord";
}
{
assertion = lib.hasInfix "units = \"si\"" toml;
message = "units should be si";
}
{
assertion = lib.hasInfix "color = \"always\"" toml;
message = "color should be always";
}
{
assertion = lib.hasInfix "resolve = false" toml;
message = "resolve should be false";
}
{
assertion = lib.hasInfix "output_format = \"json\"" toml;
message = "output_format should be json";
}
{
assertion = lib.hasInfix "sort_by = \"pid\"" toml;
message = "sort_by should be pid";
}
];
};
# test 5: output format enum validation
outputFormatCsv = {
name = "output-format-csv";
config = {
programs.snitch = {
enable = true;
settings.defaults.output_format = "csv";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "output_format = \"csv\"" toml;
message = "output_format should accept csv";
}
];
};
# test 6: units enum validation
unitsIec = {
name = "units-iec";
config = {
programs.snitch = {
enable = true;
settings.defaults.units = "iec";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "units = \"iec\"" toml;
message = "units should accept iec";
}
];
};
# test 7: color never value
colorNever = {
name = "color-never";
config = {
programs.snitch = {
enable = true;
settings.defaults.color = "never";
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "color = \"never\"" toml;
message = "color should accept never";
}
];
};
# test 8: freeform type allows custom keys
freeformCustomKeys = {
name = "freeform-custom-keys";
config = {
programs.snitch = {
enable = true;
settings = {
defaults.theme = "dracula";
custom_section = {
custom_key = "custom_value";
};
};
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "custom_key" toml;
message = "freeform type should allow custom keys";
}
];
};
# test 9: all themes evaluate correctly
allThemes =
let
themes = [
"ansi"
"catppuccin-mocha"
"catppuccin-macchiato"
"catppuccin-frappe"
"catppuccin-latte"
"gruvbox-dark"
"gruvbox-light"
"dracula"
"nord"
"tokyo-night"
"tokyo-night-storm"
"tokyo-night-light"
"solarized-dark"
"solarized-light"
"one-dark"
"mono"
"auto"
];
in
{
name = "all-themes";
# use the last theme as the test config
config = {
programs.snitch = {
enable = true;
settings.defaults.theme = "auto";
};
};
assertions = evalResult:
let
# verify all themes can be set by evaluating them
themeResults = map
(theme:
let
result = evalModule {
programs.snitch = {
enable = true;
settings.defaults.theme = theme;
};
};
toml = readGeneratedToml result;
in
{
inherit theme;
success = toml != null && lib.hasInfix theme toml;
}
)
themes;
allSucceeded = lib.all (r: r.success) themeResults;
in
[
{
assertion = allSucceeded;
message = "all themes should evaluate correctly: ${
lib.concatMapStringsSep ", "
(r: "${r.theme}=${if r.success then "ok" else "fail"}")
themeResults
}";
}
];
};
# test 10: fields list serialization
fieldsListSerialization = {
name = "fields-list-serialization";
config = {
programs.snitch = {
enable = true;
settings.defaults.fields = [ "pid" "process" "proto" "state" ];
};
};
assertions = evalResult:
let
toml = readGeneratedToml evalResult;
in
[
{
assertion = lib.hasInfix "pid" toml && lib.hasInfix "process" toml;
message = "fields list should be serialized correctly";
}
];
};
};
# run all tests and collect results
runTests =
let
testResults = lib.mapAttrsToList
(name: test:
let
evalResult = evalModule test.config;
assertions = test.assertions evalResult;
failures = lib.filter (a: !a.assertion) assertions;
in
{
inherit name;
testName = test.name;
passed = failures == [ ];
failures = map (f: f.message) failures;
}
)
tests;
allPassed = lib.all (r: r.passed) testResults;
failedTests = lib.filter (r: !r.passed) testResults;
summary = ''
========================================
home manager module test results
========================================
total tests: ${toString (builtins.length testResults)}
passed: ${toString (builtins.length (lib.filter (r: r.passed) testResults))}
failed: ${toString (builtins.length failedTests)}
========================================
${lib.concatMapStringsSep "\n" (r:
if r.passed
then "[yes] ${r.testName}"
else "[no] ${r.testName}\n ${lib.concatStringsSep "\n " r.failures}"
) testResults}
========================================
'';
in
{
inherit testResults allPassed failedTests summary;
};
results = runTests;
in
pkgs.runCommand "hm-module-test"
{
passthru = {
inherit results;
# expose for debugging
inherit evalModule tests;
};
}
(
if results.allPassed
then ''
echo "${results.summary}"
echo "all tests passed"
touch $out
''
else ''
echo "${results.summary}"
echo ""
echo "failed tests:"
${lib.concatMapStringsSep "\n" (t: ''
echo " - ${t.testName}: ${lib.concatStringsSep ", " t.failures}"
'') results.failedTests}
exit 1
''
)