refactor: Rewrote in golang.

This commit is contained in:
2025-11-03 16:19:17 -05:00
parent e38f8a4c8f
commit d6ef585a45
27 changed files with 1945 additions and 343 deletions

207
app/config.go Normal file
View File

@@ -0,0 +1,207 @@
package app
import (
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"filippo.io/age"
"filippo.io/age/agessh"
)
type Config struct {
Keys []SshKeyPair `json:"keys"`
ScanConfig scanConfig `json:"scan"`
}
type SshKeyPair struct {
Private string `json:"private"` // Path to the private key file
Public string `json:"public"` // Path to the public key file
}
type scanConfig struct {
Matcher string `json:"matcher"`
Exclude string `json:"exclude"`
Include string `json:"include"`
}
// Create a fresh config with sensible defaults.
func NewConfig(privateKeyPaths []string) Config {
var keys = []SshKeyPair{}
for _, priv := range privateKeyPaths {
var key = SshKeyPair{
Private: priv,
Public: priv + ".pub",
}
keys = append(keys, key)
}
return Config{
Keys: keys,
ScanConfig: scanConfig{
Matcher: "\\.env",
Exclude: "*.envrc",
Include: "~",
},
}
}
// Read the Config from disk.
func LoadConfig() (*Config, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, err
}
configPath := filepath.Join(homeDir, ".envr", "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("No config file found. Please run `envr init` to generate one.")
} else {
return nil, err
}
}
var config Config
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// Write the Config to disk.
func (c *Config) Save() error {
// Create the ~/.envr directory
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
configDir := filepath.Join(homeDir, ".envr")
if err := os.MkdirAll(configDir, 0755); err != nil {
return err
}
configPath := filepath.Join(configDir, "config.json")
// Check if file exists and is not empty
if info, err := os.Stat(configPath); err == nil {
if info.Size() > 0 {
return os.ErrExist
}
}
data, err := json.MarshalIndent(c, "", " ")
if err != nil {
return err
}
return os.WriteFile(configPath, data, 0644)
}
// Use fd to find all ignored .env files that match the config's parameters
func (c Config) scan() (paths []string, err error) {
searchPath, err := c.searchPath()
if err != nil {
return []string{}, err
}
// Find all files (including ignored ones)
fmt.Printf("Searching for all files in \"%s\"...\n", searchPath)
allCmd := exec.Command("fd", "-a", c.ScanConfig.Matcher, "-E", c.ScanConfig.Exclude, "-HI", searchPath)
allOutput, err := allCmd.Output()
if err != nil {
return []string{}, err
}
allFiles := strings.Split(strings.TrimSpace(string(allOutput)), "\n")
if len(allFiles) == 1 && allFiles[0] == "" {
allFiles = []string{}
}
// Find unignored files
fmt.Printf("Search for unignored fies in \"%s\"...\n", searchPath)
unignoredCmd := exec.Command("fd", "-a", c.ScanConfig.Matcher, "-E", c.ScanConfig.Exclude, "-H", searchPath)
unignoredOutput, err := unignoredCmd.Output()
if err != nil {
return []string{}, err
}
unignoredFiles := strings.Split(strings.TrimSpace(string(unignoredOutput)), "\n")
if len(unignoredFiles) == 1 && unignoredFiles[0] == "" {
unignoredFiles = []string{}
}
// Create a map for faster lookup
unignoredMap := make(map[string]bool)
for _, file := range unignoredFiles {
unignoredMap[file] = true
}
// Filter to get only ignored files
var ignoredFiles []string
for _, file := range allFiles {
if !unignoredMap[file] {
ignoredFiles = append(ignoredFiles, file)
}
}
return ignoredFiles, nil
}
func (c Config) searchPath() (path string, err error) {
include := c.ScanConfig.Include
if include == "~" {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
return homeDir, nil
}
absPath, err := filepath.Abs(include)
if err != nil {
return "", err
}
return absPath, nil
}
// TODO: Should this be private?
func (s SshKeyPair) Identity() (age.Identity, error) {
sshKey, err := os.ReadFile(s.Private)
if err != nil {
return nil, fmt.Errorf("failed to read SSH key: %w", err)
}
id, err := agessh.ParseIdentity(sshKey)
if err != nil {
return nil, fmt.Errorf("failed to parse SSH identity: %w", err)
}
return id, nil
}
// TODO: Should this be private?
func (s SshKeyPair) Recipient() (age.Recipient, error) {
sshKey, err := os.ReadFile(s.Public)
if err != nil {
return nil, fmt.Errorf("failed to read SSH key: %w", err)
}
id, err := agessh.ParseRecipient(string(sshKey))
if err != nil {
return nil, fmt.Errorf("failed to parse SSH identity: %w", err)
}
return id, nil
}

376
app/db.go Normal file
View File

@@ -0,0 +1,376 @@
package app
import (
"database/sql"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"filippo.io/age"
_ "modernc.org/sqlite"
)
// CloseMode determines whether or not the in-memory DB should be saved to disk
// before closing the connection.
type CloseMode int
const (
ReadOnly CloseMode = iota
Write
)
type Db struct {
db *sql.DB
cfg Config
features *AvailableFeatures
}
func Open() (*Db, error) {
cfg, err := LoadConfig()
if err != nil {
return nil, err
}
if _, err := os.Stat("/home/spencer/.envr/data.age"); err != nil {
// Create a new DB
db, err := newDb()
return &Db{db, *cfg, nil}, err
} else {
// Open the existing DB
tmpFile, err := os.CreateTemp("", "envr-*.db")
if err != nil {
return nil, fmt.Errorf("failed to create temp file: %w", err)
}
defer tmpFile.Close()
defer os.Remove(tmpFile.Name())
err = decryptDb(tmpFile.Name(), (*cfg).Keys)
if err != nil {
return nil, fmt.Errorf("failed to decrypt database: %w", err)
}
memDb, err := newDb()
if err != nil {
return nil, fmt.Errorf("failed to open temp database: %w", err)
}
restoreDB(tmpFile.Name(), memDb)
return &Db{memDb, *cfg, nil}, nil
}
}
// Creates the database for the first time
func newDb() (*sql.DB, error) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
return nil, err
} else {
_, err := db.Exec(`create table envr_env_files (
path text primary key not null
, dir text not null
, remotes text -- JSON
, sha256 text not null
, contents text not null
);`)
if err != nil {
return nil, err
} else {
return db, err
}
}
}
// Decrypt the database from the age file into a temp sqlite file.
func decryptDb(tmpFilePath string, keys []SshKeyPair) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user home directory: %w", err)
}
tmpFile, err := os.OpenFile(tmpFilePath, os.O_WRONLY, 0)
if err != nil {
return fmt.Errorf("failed to open temp file: %w", err)
}
defer tmpFile.Close()
ageFilePath := filepath.Join(homeDir, ".envr", "data.age")
ageFile, err := os.Open(ageFilePath)
if err != nil {
return fmt.Errorf("failed to open age file: %w", err)
}
defer ageFile.Close()
identities := make([]age.Identity, 0, len(keys))
for _, key := range keys {
id, err := key.Identity()
if err != nil {
return err
}
identities = append(identities, id)
}
reader, err := age.Decrypt(ageFile, identities[:]...)
if err != nil {
return fmt.Errorf("failed to decrypt age file: %w", err)
}
_, err = io.Copy(tmpFile, reader)
if err != nil {
return fmt.Errorf("failed to copy decrypted content: %w", err)
}
return nil
}
// Restore the database from a file into memory
func restoreDB(path string, destDB *sql.DB) error {
// Attach the source database
_, err := destDB.Exec("ATTACH DATABASE ? AS source", path)
if err != nil {
return fmt.Errorf("failed to attach database: %w", err)
}
defer destDB.Exec("DETACH DATABASE source")
// Copy data from source to destination
_, err = destDB.Exec("INSERT INTO main.envr_env_files SELECT * FROM source.envr_env_files")
if err != nil {
return fmt.Errorf("failed to copy data: %w", err)
}
return nil
}
// Returns all the EnvFiles present in the database.
func (db *Db) List() (results []EnvFile, err error) {
rows, err := db.db.Query("select * from envr_env_files")
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var envFile EnvFile
var remotesJSON string
err := rows.Scan(&envFile.Path, &envFile.Dir, &remotesJSON, &envFile.Sha256, &envFile.contents)
if err != nil {
return nil, err
}
// TODO: unmarshal remotesJSON into envFile.remotes
results = append(results, envFile)
}
if err = rows.Err(); err != nil {
return nil, err
}
return results, nil
}
func (db *Db) Close(mode CloseMode) error {
defer db.db.Close()
if mode == Write {
// Create tmp file
tmpFile, err := os.CreateTemp("", "envr-*.db")
if err != nil {
return fmt.Errorf("failed to create temp file: %w", err)
}
defer tmpFile.Close()
defer os.Remove(tmpFile.Name())
if err := backupDb(db.db, tmpFile.Name()); err != nil {
return err
}
if err := encryptDb(tmpFile.Name(), db.cfg.Keys); err != nil {
return err
}
}
return nil
}
// Save the in-memory database to a tmp file.
func backupDb(memDb *sql.DB, tmpFilePath string) error {
_, err := memDb.Exec("VACUUM INTO ?", tmpFilePath)
if err != nil {
return fmt.Errorf("failed to vacuum database to file: %w", err)
}
return nil
}
// Encrypt the database from the temp sqlite file into an age file.
func encryptDb(tmpFilePath string, keys []SshKeyPair) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user home directory: %w", err)
}
ageFilePath := filepath.Join(homeDir, ".envr", "data.age")
// Ensure .envr directory exists
err = os.MkdirAll(filepath.Dir(ageFilePath), 0755)
if err != nil {
return fmt.Errorf("failed to create .envr directory: %w", err)
}
// Open temp file for reading
tmpFile, err := os.Open(tmpFilePath)
if err != nil {
return fmt.Errorf("failed to open temp file: %w", err)
}
defer tmpFile.Close()
// Open/create age file for writing (this preserves hardlinks)
ageFile, err := os.OpenFile(ageFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to open age file: %w", err)
}
defer ageFile.Close()
recipients := make([]age.Recipient, 0, len(keys))
for _, key := range keys {
recipient, err := key.Recipient()
if err != nil {
return err
}
recipients = append(recipients, recipient)
}
writer, err := age.Encrypt(ageFile, recipients...)
if err != nil {
return fmt.Errorf("failed to create age writer: %w", err)
}
_, err = io.Copy(writer, tmpFile)
if err != nil {
return fmt.Errorf("failed to encrypt and write data: %w", err)
}
err = writer.Close()
if err != nil {
return fmt.Errorf("failed to close age writer: %w", err)
}
return nil
}
func (db *Db) Insert(file EnvFile) error {
// Marshal remotes to JSON
remotesJSON, err := json.Marshal(file.Remotes)
if err != nil {
return fmt.Errorf("failed to marshal remotes: %w", err)
}
// Insert into database
_, err = db.db.Exec(`
INSERT OR REPLACE INTO envr_env_files (path, dir, remotes, sha256, contents)
VALUES (?, ?, ?, ?, ?)
`, file.Path, file.Dir, string(remotesJSON), file.Sha256, file.contents)
if err != nil {
return fmt.Errorf("failed to insert env file: %w", err)
}
return nil
}
// Select a single EnvFile from the database.
func (db *Db) Fetch(path string) (envFile EnvFile, err error) {
var remotesJSON string
row := db.db.QueryRow("SELECT path, dir, remotes, sha256, contents FROM envr_env_files WHERE path = ?", path)
err = row.Scan(&envFile.Path, &envFile.Dir, &remotesJSON, &envFile.Sha256, &envFile.contents)
if err != nil {
return EnvFile{}, fmt.Errorf("failed to fetch env file: %w", err)
}
if err = json.Unmarshal([]byte(remotesJSON), &envFile.Remotes); err != nil {
return EnvFile{}, fmt.Errorf("failed to unmarshal remotes: %w", err)
}
return envFile, nil
}
// Removes a file from the database, if present.
func (db *Db) Delete(path string) error {
result, err := db.db.Exec("DELETE FROM envr_env_files WHERE path = ?", path)
if err != nil {
return fmt.Errorf("failed to delete env file: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("no file found with path: %s", path)
}
return nil
}
// Finds .env files in the filesystem that aren't present in the database.
func (db *Db) Scan() ([]string, error) {
all_paths, err := db.cfg.scan()
if err != nil {
return []string{}, err
}
untracked_paths := make([]string, 0, len(all_paths)/2)
env_files, err := db.List()
if err != nil {
return untracked_paths, err
}
for _, path := range all_paths {
backed_up := slices.ContainsFunc(env_files, func(e EnvFile) bool {
return e.Path == path
})
if backed_up {
continue
} else {
untracked_paths = append(untracked_paths, path)
}
}
return untracked_paths, nil
}
// Determine the available features on the installed system.
func (db *Db) Features() AvailableFeatures {
if db.features == nil {
feats := checkFeatures()
db.features = &feats
}
return *db.features
}
// Returns nil if [Db.Scan] is safe to use, null otherwise.
func (db *Db) CanScan() error {
if db.Features()&Fd == 0 {
return fmt.Errorf(
"please install fd to use the scan function (https://github.com/sharkdp/fd)",
)
} else {
return nil
}
}

169
app/env_file.go Normal file
View File

@@ -0,0 +1,169 @@
package app
import (
"crypto/sha256"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
)
type EnvFile struct {
Path string
Dir string
Remotes []string // []string
Sha256 string
contents string
}
// The result returned by [EnvFile.Sync]
type EnvFileSyncResult int
const (
// The struct has been updated from the filesystem
// and should be updated in the database.
Updated EnvFileSyncResult = iota
// The filesystem has been restored to match the struct
// no further action is required.
Restored
Error
// The filesystem contents matches the struct
// no further action is required.
Noop
)
func NewEnvFile(path string) EnvFile {
// Get absolute path and directory
absPath, err := filepath.Abs(path)
if err != nil {
panic(fmt.Errorf("failed to get absolute path: %w", err))
}
dir := filepath.Dir(absPath)
// Get git remotes
remotes := getGitRemotes(dir)
// Read the file contents
contents, err := os.ReadFile(path)
if err != nil {
panic(fmt.Errorf("failed to read file %s: %w", path, err))
}
// Calculate SHA256 hash
hash := sha256.Sum256(contents)
sha256Hash := fmt.Sprintf("%x", hash)
return EnvFile{
Path: absPath,
Dir: dir,
Remotes: remotes,
Sha256: sha256Hash,
contents: string(contents),
}
}
func getGitRemotes(dir string) []string {
// TODO: Check for Git flag and change behaviour if unset.
cmd := exec.Command("git", "remote", "-v")
cmd.Dir = dir
output, err := cmd.Output()
if err != nil {
// Not a git repository or git command failed
return []string{}
}
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
remoteSet := make(map[string]bool)
for _, line := range lines {
if line == "" {
continue
}
parts := strings.Fields(line)
if len(parts) >= 2 {
remoteSet[parts[1]] = true
}
}
remotes := make([]string, 0, len(remoteSet))
for remote := range remoteSet {
remotes = append(remotes, remote)
}
return remotes
}
// Install the file into the file system
func (file EnvFile) Restore() error {
// TODO: Handle restores more cleanly
// Ensure the directory exists
if err := os.MkdirAll(filepath.Dir(file.Path), 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
// Check if file already exists
if _, err := os.Stat(file.Path); err == nil {
return fmt.Errorf("file already exists: %s", file.Path)
}
// Write the contents to the file
if err := os.WriteFile(file.Path, []byte(file.contents), 0644); err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
return nil
}
// Try to reconcile the EnvFile with the filesystem.
//
// If Updated is returned, [Db.Insert] should be called on file.
func (file *EnvFile) Sync() (result EnvFileSyncResult, err error) {
// Check if the path exists in the file system
_, err = os.Stat(file.Path)
if err == nil {
contents, err := os.ReadFile(file.Path)
if err != nil {
return Error, fmt.Errorf("failed to read file for SHA comparison: %w", err)
}
// Check if sha matches by reading the current file and calculating its hash
hash := sha256.Sum256(contents)
currentSha := fmt.Sprintf("%x", hash)
if file.Sha256 == currentSha {
// Nothing to do
return Noop, nil
} else {
if err = file.Backup(); err != nil {
return Error, err
} else {
return Updated, nil
}
}
} else {
if err = file.Restore(); err != nil {
return Error, err
} else {
return Restored, nil
}
}
}
// Update the EnvFile using the file system
func (file *EnvFile) Backup() error {
// Read the contents of the file
contents, err := os.ReadFile(file.Path)
if err != nil {
return fmt.Errorf("failed to read file %s: %w", file.Path, err)
}
// Update file.contents to match
file.contents = string(contents)
// Update file.sha256
hash := sha256.Sum256(contents)
file.Sha256 = fmt.Sprintf("%x", hash)
return nil
}

32
app/features.go Normal file
View File

@@ -0,0 +1,32 @@
package app
import (
"os/exec"
)
// Represents which binaries are present in $PATH.
// Used to fail safely when required features are unavailable
type AvailableFeatures int
const (
Git AvailableFeatures = 1
// fd
Fd AvailableFeatures = 2
// All features are present
All AvailableFeatures = Git & Fd
)
// Checks for available features.
func checkFeatures() (feats AvailableFeatures) {
// Check for git binary
if _, err := exec.LookPath("git"); err == nil {
feats |= Git
}
// Check for fd binary
if _, err := exec.LookPath("fd"); err == nil {
feats |= Fd
}
return feats
}