From 13c2f9cffa624fdf498f3b61fab9d809b92e026e Mon Sep 17 00:00:00 2001 From: bndw Date: Sun, 28 Dec 2025 09:21:08 -0800 Subject: init --- internal/ssh/client.go | 357 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 internal/ssh/client.go (limited to 'internal/ssh') diff --git a/internal/ssh/client.go b/internal/ssh/client.go new file mode 100644 index 0000000..1cd336c --- /dev/null +++ b/internal/ssh/client.go @@ -0,0 +1,357 @@ +package ssh + +import ( + "bufio" + "bytes" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// Client represents an SSH connection to a remote host +type Client struct { + host string + client *ssh.Client +} + +// sshConfig holds SSH configuration for a host +type sshConfig struct { + Host string + HostName string + User string + Port string + IdentityFile string +} + +// Connect establishes an SSH connection to the remote host +// Supports both SSH config aliases (e.g., "myserver") and user@host format +func Connect(host string) (*Client, error) { + var user, addr string + var identityFile string + + // Try to read SSH config first + cfg, err := readSSHConfig(host) + if err == nil && cfg.HostName != "" { + // Use SSH config + user = cfg.User + addr = cfg.HostName + if cfg.Port != "" { + addr = addr + ":" + cfg.Port + } else { + addr = addr + ":22" + } + identityFile = cfg.IdentityFile + } else { + // Fall back to parsing user@host format + parts := strings.SplitN(host, "@", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("host '%s' not found in SSH config and not in user@host format", host) + } + user = parts[0] + addr = parts[1] + + // Add default port if not specified + if !strings.Contains(addr, ":") { + addr = addr + ":22" + } + } + + // Build authentication methods + var authMethods []ssh.AuthMethod + + // Try identity file from SSH config first + if identityFile != "" { + if authMethod, err := publicKeyFromFile(identityFile); err == nil { + authMethods = append(authMethods, authMethod) + } + } + + // Try SSH agent + if authMethod, err := sshAgent(); err == nil { + authMethods = append(authMethods, authMethod) + } + + // Try default key files + if authMethod, err := publicKeyFile(); err == nil { + authMethods = append(authMethods, authMethod) + } + + if len(authMethods) == 0 { + return nil, fmt.Errorf("no SSH authentication method available") + } + + config := &ssh.ClientConfig{ + User: user, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: Consider using known_hosts + } + + client, err := ssh.Dial("tcp", addr, config) + if err != nil { + return nil, fmt.Errorf("failed to connect to %s: %w", host, err) + } + + return &Client{ + host: host, + client: client, + }, nil +} + +// Close closes the SSH connection +func (c *Client) Close() error { + return c.client.Close() +} + +// Run executes a command on the remote host and returns the output +func (c *Client) Run(cmd string) (string, error) { + session, err := c.client.NewSession() + if err != nil { + return "", err + } + defer session.Close() + + var stdout, stderr bytes.Buffer + session.Stdout = &stdout + session.Stderr = &stderr + + if err := session.Run(cmd); err != nil { + return "", fmt.Errorf("command failed: %w\nstderr: %s", err, stderr.String()) + } + + return stdout.String(), nil +} + +// RunSudo executes a command with sudo on the remote host +func (c *Client) RunSudo(cmd string) (string, error) { + return c.Run("sudo " + cmd) +} + +// Upload copies a local file to the remote host using scp +func (c *Client) Upload(localPath, remotePath string) error { + // Use external scp command for simplicity + // Format: scp -o StrictHostKeyChecking=no localPath user@host:remotePath + cmd := exec.Command("scp", "-o", "StrictHostKeyChecking=no", localPath, c.host+":"+remotePath) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("scp failed: %w\nstderr: %s", err, stderr.String()) + } + + return nil +} + +// UploadDir copies a local directory to the remote host using rsync +func (c *Client) UploadDir(localDir, remoteDir string) error { + // Use rsync for directory uploads + // Format: rsync -avz --delete localDir/ user@host:remoteDir/ + localDir = strings.TrimSuffix(localDir, "/") + "/" + remoteDir = strings.TrimSuffix(remoteDir, "/") + "/" + + cmd := exec.Command("rsync", "-avz", "--delete", + "-e", "ssh -o StrictHostKeyChecking=no", + localDir, c.host+":"+remoteDir) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("rsync failed: %w\nstderr: %s", err, stderr.String()) + } + + return nil +} + +// WriteFile creates a file with the given content on the remote host +func (c *Client) WriteFile(remotePath, content string) error { + session, err := c.client.NewSession() + if err != nil { + return err + } + defer session.Close() + + // Use cat to write content to file + cmd := fmt.Sprintf("cat > %s", remotePath) + session.Stdin = strings.NewReader(content) + + var stderr bytes.Buffer + session.Stderr = &stderr + + if err := session.Run(cmd); err != nil { + return fmt.Errorf("write file failed: %w\nstderr: %s", err, stderr.String()) + } + + return nil +} + +// WriteSudoFile creates a file with the given content using sudo +func (c *Client) WriteSudoFile(remotePath, content string) error { + session, err := c.client.NewSession() + if err != nil { + return err + } + defer session.Close() + + // Use sudo tee to write content to file + cmd := fmt.Sprintf("sudo tee %s > /dev/null", remotePath) + session.Stdin = strings.NewReader(content) + + var stderr bytes.Buffer + session.Stderr = &stderr + + if err := session.Run(cmd); err != nil { + return fmt.Errorf("write file with sudo failed: %w\nstderr: %s", err, stderr.String()) + } + + return nil +} + +// readSSHConfig reads and parses the SSH config file for a given host +func readSSHConfig(host string) (*sshConfig, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + configPath := filepath.Join(home, ".ssh", "config") + file, err := os.Open(configPath) + if err != nil { + return nil, err + } + defer file.Close() + + cfg := &sshConfig{} + var currentHost string + var matchedHost bool + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // Skip comments and empty lines + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + key := strings.ToLower(fields[0]) + value := fields[1] + + // Expand ~ in paths + if strings.HasPrefix(value, "~/") { + value = filepath.Join(home, value[2:]) + } + + switch key { + case "host": + currentHost = value + if currentHost == host { + matchedHost = true + cfg.Host = host + } else { + matchedHost = false + } + case "hostname": + if matchedHost { + cfg.HostName = value + } + case "user": + if matchedHost { + cfg.User = value + } + case "port": + if matchedHost { + cfg.Port = value + } + case "identityfile": + if matchedHost { + cfg.IdentityFile = value + } + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + if cfg.Host == "" { + return nil, fmt.Errorf("host %s not found in SSH config", host) + } + + return cfg, nil +} + +// sshAgent returns an auth method using SSH agent +func sshAgent() (ssh.AuthMethod, error) { + socket := os.Getenv("SSH_AUTH_SOCK") + if socket == "" { + return nil, fmt.Errorf("SSH_AUTH_SOCK not set") + } + + conn, err := net.Dial("unix", socket) + if err != nil { + return nil, fmt.Errorf("failed to connect to SSH agent: %w", err) + } + + agentClient := agent.NewClient(conn) + return ssh.PublicKeysCallback(agentClient.Signers), nil +} + +// publicKeyFromFile returns an auth method from a specific private key file +func publicKeyFromFile(keyPath string) (ssh.AuthMethod, error) { + key, err := os.ReadFile(keyPath) + if err != nil { + return nil, err + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil, err + } + + return ssh.PublicKeys(signer), nil +} + +// publicKeyFile returns an auth method using a private key file +func publicKeyFile() (ssh.AuthMethod, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + // Try common key locations + keyPaths := []string{ + filepath.Join(home, ".ssh", "id_rsa"), + filepath.Join(home, ".ssh", "id_ed25519"), + filepath.Join(home, ".ssh", "id_ecdsa"), + } + + for _, keyPath := range keyPaths { + if _, err := os.Stat(keyPath); err == nil { + key, err := os.ReadFile(keyPath) + if err != nil { + continue + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + continue + } + + return ssh.PublicKeys(signer), nil + } + } + + return nil, fmt.Errorf("no SSH private key found") +} -- cgit v1.2.3