diff options
Diffstat (limited to 'internal/ssh/client.go')
| -rw-r--r-- | internal/ssh/client.go | 357 |
1 files changed, 357 insertions, 0 deletions
diff --git a/internal/ssh/client.go b/internal/ssh/client.go new file mode 100644 index 0000000..1cd336c --- /dev/null +++ b/internal/ssh/client.go | |||
| @@ -0,0 +1,357 @@ | |||
| 1 | package ssh | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "bufio" | ||
| 5 | "bytes" | ||
| 6 | "fmt" | ||
| 7 | "net" | ||
| 8 | "os" | ||
| 9 | "os/exec" | ||
| 10 | "path/filepath" | ||
| 11 | "strings" | ||
| 12 | |||
| 13 | "golang.org/x/crypto/ssh" | ||
| 14 | "golang.org/x/crypto/ssh/agent" | ||
| 15 | ) | ||
| 16 | |||
| 17 | // Client represents an SSH connection to a remote host | ||
| 18 | type Client struct { | ||
| 19 | host string | ||
| 20 | client *ssh.Client | ||
| 21 | } | ||
| 22 | |||
| 23 | // sshConfig holds SSH configuration for a host | ||
| 24 | type sshConfig struct { | ||
| 25 | Host string | ||
| 26 | HostName string | ||
| 27 | User string | ||
| 28 | Port string | ||
| 29 | IdentityFile string | ||
| 30 | } | ||
| 31 | |||
| 32 | // Connect establishes an SSH connection to the remote host | ||
| 33 | // Supports both SSH config aliases (e.g., "myserver") and user@host format | ||
| 34 | func Connect(host string) (*Client, error) { | ||
| 35 | var user, addr string | ||
| 36 | var identityFile string | ||
| 37 | |||
| 38 | // Try to read SSH config first | ||
| 39 | cfg, err := readSSHConfig(host) | ||
| 40 | if err == nil && cfg.HostName != "" { | ||
| 41 | // Use SSH config | ||
| 42 | user = cfg.User | ||
| 43 | addr = cfg.HostName | ||
| 44 | if cfg.Port != "" { | ||
| 45 | addr = addr + ":" + cfg.Port | ||
| 46 | } else { | ||
| 47 | addr = addr + ":22" | ||
| 48 | } | ||
| 49 | identityFile = cfg.IdentityFile | ||
| 50 | } else { | ||
| 51 | // Fall back to parsing user@host format | ||
| 52 | parts := strings.SplitN(host, "@", 2) | ||
| 53 | if len(parts) != 2 { | ||
| 54 | return nil, fmt.Errorf("host '%s' not found in SSH config and not in user@host format", host) | ||
| 55 | } | ||
| 56 | user = parts[0] | ||
| 57 | addr = parts[1] | ||
| 58 | |||
| 59 | // Add default port if not specified | ||
| 60 | if !strings.Contains(addr, ":") { | ||
| 61 | addr = addr + ":22" | ||
| 62 | } | ||
| 63 | } | ||
| 64 | |||
| 65 | // Build authentication methods | ||
| 66 | var authMethods []ssh.AuthMethod | ||
| 67 | |||
| 68 | // Try identity file from SSH config first | ||
| 69 | if identityFile != "" { | ||
| 70 | if authMethod, err := publicKeyFromFile(identityFile); err == nil { | ||
| 71 | authMethods = append(authMethods, authMethod) | ||
| 72 | } | ||
| 73 | } | ||
| 74 | |||
| 75 | // Try SSH agent | ||
| 76 | if authMethod, err := sshAgent(); err == nil { | ||
| 77 | authMethods = append(authMethods, authMethod) | ||
| 78 | } | ||
| 79 | |||
| 80 | // Try default key files | ||
| 81 | if authMethod, err := publicKeyFile(); err == nil { | ||
| 82 | authMethods = append(authMethods, authMethod) | ||
| 83 | } | ||
| 84 | |||
| 85 | if len(authMethods) == 0 { | ||
| 86 | return nil, fmt.Errorf("no SSH authentication method available") | ||
| 87 | } | ||
| 88 | |||
| 89 | config := &ssh.ClientConfig{ | ||
| 90 | User: user, | ||
| 91 | Auth: authMethods, | ||
| 92 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: Consider using known_hosts | ||
| 93 | } | ||
| 94 | |||
| 95 | client, err := ssh.Dial("tcp", addr, config) | ||
| 96 | if err != nil { | ||
| 97 | return nil, fmt.Errorf("failed to connect to %s: %w", host, err) | ||
| 98 | } | ||
| 99 | |||
| 100 | return &Client{ | ||
| 101 | host: host, | ||
| 102 | client: client, | ||
| 103 | }, nil | ||
| 104 | } | ||
| 105 | |||
| 106 | // Close closes the SSH connection | ||
| 107 | func (c *Client) Close() error { | ||
| 108 | return c.client.Close() | ||
| 109 | } | ||
| 110 | |||
| 111 | // Run executes a command on the remote host and returns the output | ||
| 112 | func (c *Client) Run(cmd string) (string, error) { | ||
| 113 | session, err := c.client.NewSession() | ||
| 114 | if err != nil { | ||
| 115 | return "", err | ||
| 116 | } | ||
| 117 | defer session.Close() | ||
| 118 | |||
| 119 | var stdout, stderr bytes.Buffer | ||
| 120 | session.Stdout = &stdout | ||
| 121 | session.Stderr = &stderr | ||
| 122 | |||
| 123 | if err := session.Run(cmd); err != nil { | ||
| 124 | return "", fmt.Errorf("command failed: %w\nstderr: %s", err, stderr.String()) | ||
| 125 | } | ||
| 126 | |||
| 127 | return stdout.String(), nil | ||
| 128 | } | ||
| 129 | |||
| 130 | // RunSudo executes a command with sudo on the remote host | ||
| 131 | func (c *Client) RunSudo(cmd string) (string, error) { | ||
| 132 | return c.Run("sudo " + cmd) | ||
| 133 | } | ||
| 134 | |||
| 135 | // Upload copies a local file to the remote host using scp | ||
| 136 | func (c *Client) Upload(localPath, remotePath string) error { | ||
| 137 | // Use external scp command for simplicity | ||
| 138 | // Format: scp -o StrictHostKeyChecking=no localPath user@host:remotePath | ||
| 139 | cmd := exec.Command("scp", "-o", "StrictHostKeyChecking=no", localPath, c.host+":"+remotePath) | ||
| 140 | |||
| 141 | var stderr bytes.Buffer | ||
| 142 | cmd.Stderr = &stderr | ||
| 143 | |||
| 144 | if err := cmd.Run(); err != nil { | ||
| 145 | return fmt.Errorf("scp failed: %w\nstderr: %s", err, stderr.String()) | ||
| 146 | } | ||
| 147 | |||
| 148 | return nil | ||
| 149 | } | ||
| 150 | |||
| 151 | // UploadDir copies a local directory to the remote host using rsync | ||
| 152 | func (c *Client) UploadDir(localDir, remoteDir string) error { | ||
| 153 | // Use rsync for directory uploads | ||
| 154 | // Format: rsync -avz --delete localDir/ user@host:remoteDir/ | ||
| 155 | localDir = strings.TrimSuffix(localDir, "/") + "/" | ||
| 156 | remoteDir = strings.TrimSuffix(remoteDir, "/") + "/" | ||
| 157 | |||
| 158 | cmd := exec.Command("rsync", "-avz", "--delete", | ||
| 159 | "-e", "ssh -o StrictHostKeyChecking=no", | ||
| 160 | localDir, c.host+":"+remoteDir) | ||
| 161 | |||
| 162 | var stderr bytes.Buffer | ||
| 163 | cmd.Stderr = &stderr | ||
| 164 | |||
| 165 | if err := cmd.Run(); err != nil { | ||
| 166 | return fmt.Errorf("rsync failed: %w\nstderr: %s", err, stderr.String()) | ||
| 167 | } | ||
| 168 | |||
| 169 | return nil | ||
| 170 | } | ||
| 171 | |||
| 172 | // WriteFile creates a file with the given content on the remote host | ||
| 173 | func (c *Client) WriteFile(remotePath, content string) error { | ||
| 174 | session, err := c.client.NewSession() | ||
| 175 | if err != nil { | ||
| 176 | return err | ||
| 177 | } | ||
| 178 | defer session.Close() | ||
| 179 | |||
| 180 | // Use cat to write content to file | ||
| 181 | cmd := fmt.Sprintf("cat > %s", remotePath) | ||
| 182 | session.Stdin = strings.NewReader(content) | ||
| 183 | |||
| 184 | var stderr bytes.Buffer | ||
| 185 | session.Stderr = &stderr | ||
| 186 | |||
| 187 | if err := session.Run(cmd); err != nil { | ||
| 188 | return fmt.Errorf("write file failed: %w\nstderr: %s", err, stderr.String()) | ||
| 189 | } | ||
| 190 | |||
| 191 | return nil | ||
| 192 | } | ||
| 193 | |||
| 194 | // WriteSudoFile creates a file with the given content using sudo | ||
| 195 | func (c *Client) WriteSudoFile(remotePath, content string) error { | ||
| 196 | session, err := c.client.NewSession() | ||
| 197 | if err != nil { | ||
| 198 | return err | ||
| 199 | } | ||
| 200 | defer session.Close() | ||
| 201 | |||
| 202 | // Use sudo tee to write content to file | ||
| 203 | cmd := fmt.Sprintf("sudo tee %s > /dev/null", remotePath) | ||
| 204 | session.Stdin = strings.NewReader(content) | ||
| 205 | |||
| 206 | var stderr bytes.Buffer | ||
| 207 | session.Stderr = &stderr | ||
| 208 | |||
| 209 | if err := session.Run(cmd); err != nil { | ||
| 210 | return fmt.Errorf("write file with sudo failed: %w\nstderr: %s", err, stderr.String()) | ||
| 211 | } | ||
| 212 | |||
| 213 | return nil | ||
| 214 | } | ||
| 215 | |||
| 216 | // readSSHConfig reads and parses the SSH config file for a given host | ||
| 217 | func readSSHConfig(host string) (*sshConfig, error) { | ||
| 218 | home, err := os.UserHomeDir() | ||
| 219 | if err != nil { | ||
| 220 | return nil, err | ||
| 221 | } | ||
| 222 | |||
| 223 | configPath := filepath.Join(home, ".ssh", "config") | ||
| 224 | file, err := os.Open(configPath) | ||
| 225 | if err != nil { | ||
| 226 | return nil, err | ||
| 227 | } | ||
| 228 | defer file.Close() | ||
| 229 | |||
| 230 | cfg := &sshConfig{} | ||
| 231 | var currentHost string | ||
| 232 | var matchedHost bool | ||
| 233 | |||
| 234 | scanner := bufio.NewScanner(file) | ||
| 235 | for scanner.Scan() { | ||
| 236 | line := strings.TrimSpace(scanner.Text()) | ||
| 237 | |||
| 238 | // Skip comments and empty lines | ||
| 239 | if line == "" || strings.HasPrefix(line, "#") { | ||
| 240 | continue | ||
| 241 | } | ||
| 242 | |||
| 243 | fields := strings.Fields(line) | ||
| 244 | if len(fields) < 2 { | ||
| 245 | continue | ||
| 246 | } | ||
| 247 | |||
| 248 | key := strings.ToLower(fields[0]) | ||
| 249 | value := fields[1] | ||
| 250 | |||
| 251 | // Expand ~ in paths | ||
| 252 | if strings.HasPrefix(value, "~/") { | ||
| 253 | value = filepath.Join(home, value[2:]) | ||
| 254 | } | ||
| 255 | |||
| 256 | switch key { | ||
| 257 | case "host": | ||
| 258 | currentHost = value | ||
| 259 | if currentHost == host { | ||
| 260 | matchedHost = true | ||
| 261 | cfg.Host = host | ||
| 262 | } else { | ||
| 263 | matchedHost = false | ||
| 264 | } | ||
| 265 | case "hostname": | ||
| 266 | if matchedHost { | ||
| 267 | cfg.HostName = value | ||
| 268 | } | ||
| 269 | case "user": | ||
| 270 | if matchedHost { | ||
| 271 | cfg.User = value | ||
| 272 | } | ||
| 273 | case "port": | ||
| 274 | if matchedHost { | ||
| 275 | cfg.Port = value | ||
| 276 | } | ||
| 277 | case "identityfile": | ||
| 278 | if matchedHost { | ||
| 279 | cfg.IdentityFile = value | ||
| 280 | } | ||
| 281 | } | ||
| 282 | } | ||
| 283 | |||
| 284 | if err := scanner.Err(); err != nil { | ||
| 285 | return nil, err | ||
| 286 | } | ||
| 287 | |||
| 288 | if cfg.Host == "" { | ||
| 289 | return nil, fmt.Errorf("host %s not found in SSH config", host) | ||
| 290 | } | ||
| 291 | |||
| 292 | return cfg, nil | ||
| 293 | } | ||
| 294 | |||
| 295 | // sshAgent returns an auth method using SSH agent | ||
| 296 | func sshAgent() (ssh.AuthMethod, error) { | ||
| 297 | socket := os.Getenv("SSH_AUTH_SOCK") | ||
| 298 | if socket == "" { | ||
| 299 | return nil, fmt.Errorf("SSH_AUTH_SOCK not set") | ||
| 300 | } | ||
| 301 | |||
| 302 | conn, err := net.Dial("unix", socket) | ||
| 303 | if err != nil { | ||
| 304 | return nil, fmt.Errorf("failed to connect to SSH agent: %w", err) | ||
| 305 | } | ||
| 306 | |||
| 307 | agentClient := agent.NewClient(conn) | ||
| 308 | return ssh.PublicKeysCallback(agentClient.Signers), nil | ||
| 309 | } | ||
| 310 | |||
| 311 | // publicKeyFromFile returns an auth method from a specific private key file | ||
| 312 | func publicKeyFromFile(keyPath string) (ssh.AuthMethod, error) { | ||
| 313 | key, err := os.ReadFile(keyPath) | ||
| 314 | if err != nil { | ||
| 315 | return nil, err | ||
| 316 | } | ||
| 317 | |||
| 318 | signer, err := ssh.ParsePrivateKey(key) | ||
| 319 | if err != nil { | ||
| 320 | return nil, err | ||
| 321 | } | ||
| 322 | |||
| 323 | return ssh.PublicKeys(signer), nil | ||
| 324 | } | ||
| 325 | |||
| 326 | // publicKeyFile returns an auth method using a private key file | ||
| 327 | func publicKeyFile() (ssh.AuthMethod, error) { | ||
| 328 | home, err := os.UserHomeDir() | ||
| 329 | if err != nil { | ||
| 330 | return nil, err | ||
| 331 | } | ||
| 332 | |||
| 333 | // Try common key locations | ||
| 334 | keyPaths := []string{ | ||
| 335 | filepath.Join(home, ".ssh", "id_rsa"), | ||
| 336 | filepath.Join(home, ".ssh", "id_ed25519"), | ||
| 337 | filepath.Join(home, ".ssh", "id_ecdsa"), | ||
| 338 | } | ||
| 339 | |||
| 340 | for _, keyPath := range keyPaths { | ||
| 341 | if _, err := os.Stat(keyPath); err == nil { | ||
| 342 | key, err := os.ReadFile(keyPath) | ||
| 343 | if err != nil { | ||
| 344 | continue | ||
| 345 | } | ||
| 346 | |||
| 347 | signer, err := ssh.ParsePrivateKey(key) | ||
| 348 | if err != nil { | ||
| 349 | continue | ||
| 350 | } | ||
| 351 | |||
| 352 | return ssh.PublicKeys(signer), nil | ||
| 353 | } | ||
| 354 | } | ||
| 355 | |||
| 356 | return nil, fmt.Errorf("no SSH private key found") | ||
| 357 | } | ||
