summaryrefslogtreecommitdiffstats
path: root/internal/ssh/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ssh/client.go')
-rw-r--r--internal/ssh/client.go357
1 files changed, 357 insertions, 0 deletions
diff --git a/internal/ssh/client.go b/internal/ssh/client.go
new file mode 100644
index 0000000..1cd336c
--- /dev/null
+++ b/internal/ssh/client.go
@@ -0,0 +1,357 @@
1package ssh
2
3import (
4 "bufio"
5 "bytes"
6 "fmt"
7 "net"
8 "os"
9 "os/exec"
10 "path/filepath"
11 "strings"
12
13 "golang.org/x/crypto/ssh"
14 "golang.org/x/crypto/ssh/agent"
15)
16
17// Client represents an SSH connection to a remote host
18type Client struct {
19 host string
20 client *ssh.Client
21}
22
23// sshConfig holds SSH configuration for a host
24type sshConfig struct {
25 Host string
26 HostName string
27 User string
28 Port string
29 IdentityFile string
30}
31
32// Connect establishes an SSH connection to the remote host
33// Supports both SSH config aliases (e.g., "myserver") and user@host format
34func Connect(host string) (*Client, error) {
35 var user, addr string
36 var identityFile string
37
38 // Try to read SSH config first
39 cfg, err := readSSHConfig(host)
40 if err == nil && cfg.HostName != "" {
41 // Use SSH config
42 user = cfg.User
43 addr = cfg.HostName
44 if cfg.Port != "" {
45 addr = addr + ":" + cfg.Port
46 } else {
47 addr = addr + ":22"
48 }
49 identityFile = cfg.IdentityFile
50 } else {
51 // Fall back to parsing user@host format
52 parts := strings.SplitN(host, "@", 2)
53 if len(parts) != 2 {
54 return nil, fmt.Errorf("host '%s' not found in SSH config and not in user@host format", host)
55 }
56 user = parts[0]
57 addr = parts[1]
58
59 // Add default port if not specified
60 if !strings.Contains(addr, ":") {
61 addr = addr + ":22"
62 }
63 }
64
65 // Build authentication methods
66 var authMethods []ssh.AuthMethod
67
68 // Try identity file from SSH config first
69 if identityFile != "" {
70 if authMethod, err := publicKeyFromFile(identityFile); err == nil {
71 authMethods = append(authMethods, authMethod)
72 }
73 }
74
75 // Try SSH agent
76 if authMethod, err := sshAgent(); err == nil {
77 authMethods = append(authMethods, authMethod)
78 }
79
80 // Try default key files
81 if authMethod, err := publicKeyFile(); err == nil {
82 authMethods = append(authMethods, authMethod)
83 }
84
85 if len(authMethods) == 0 {
86 return nil, fmt.Errorf("no SSH authentication method available")
87 }
88
89 config := &ssh.ClientConfig{
90 User: user,
91 Auth: authMethods,
92 HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: Consider using known_hosts
93 }
94
95 client, err := ssh.Dial("tcp", addr, config)
96 if err != nil {
97 return nil, fmt.Errorf("failed to connect to %s: %w", host, err)
98 }
99
100 return &Client{
101 host: host,
102 client: client,
103 }, nil
104}
105
106// Close closes the SSH connection
107func (c *Client) Close() error {
108 return c.client.Close()
109}
110
111// Run executes a command on the remote host and returns the output
112func (c *Client) Run(cmd string) (string, error) {
113 session, err := c.client.NewSession()
114 if err != nil {
115 return "", err
116 }
117 defer session.Close()
118
119 var stdout, stderr bytes.Buffer
120 session.Stdout = &stdout
121 session.Stderr = &stderr
122
123 if err := session.Run(cmd); err != nil {
124 return "", fmt.Errorf("command failed: %w\nstderr: %s", err, stderr.String())
125 }
126
127 return stdout.String(), nil
128}
129
130// RunSudo executes a command with sudo on the remote host
131func (c *Client) RunSudo(cmd string) (string, error) {
132 return c.Run("sudo " + cmd)
133}
134
135// Upload copies a local file to the remote host using scp
136func (c *Client) Upload(localPath, remotePath string) error {
137 // Use external scp command for simplicity
138 // Format: scp -o StrictHostKeyChecking=no localPath user@host:remotePath
139 cmd := exec.Command("scp", "-o", "StrictHostKeyChecking=no", localPath, c.host+":"+remotePath)
140
141 var stderr bytes.Buffer
142 cmd.Stderr = &stderr
143
144 if err := cmd.Run(); err != nil {
145 return fmt.Errorf("scp failed: %w\nstderr: %s", err, stderr.String())
146 }
147
148 return nil
149}
150
151// UploadDir copies a local directory to the remote host using rsync
152func (c *Client) UploadDir(localDir, remoteDir string) error {
153 // Use rsync for directory uploads
154 // Format: rsync -avz --delete localDir/ user@host:remoteDir/
155 localDir = strings.TrimSuffix(localDir, "/") + "/"
156 remoteDir = strings.TrimSuffix(remoteDir, "/") + "/"
157
158 cmd := exec.Command("rsync", "-avz", "--delete",
159 "-e", "ssh -o StrictHostKeyChecking=no",
160 localDir, c.host+":"+remoteDir)
161
162 var stderr bytes.Buffer
163 cmd.Stderr = &stderr
164
165 if err := cmd.Run(); err != nil {
166 return fmt.Errorf("rsync failed: %w\nstderr: %s", err, stderr.String())
167 }
168
169 return nil
170}
171
172// WriteFile creates a file with the given content on the remote host
173func (c *Client) WriteFile(remotePath, content string) error {
174 session, err := c.client.NewSession()
175 if err != nil {
176 return err
177 }
178 defer session.Close()
179
180 // Use cat to write content to file
181 cmd := fmt.Sprintf("cat > %s", remotePath)
182 session.Stdin = strings.NewReader(content)
183
184 var stderr bytes.Buffer
185 session.Stderr = &stderr
186
187 if err := session.Run(cmd); err != nil {
188 return fmt.Errorf("write file failed: %w\nstderr: %s", err, stderr.String())
189 }
190
191 return nil
192}
193
194// WriteSudoFile creates a file with the given content using sudo
195func (c *Client) WriteSudoFile(remotePath, content string) error {
196 session, err := c.client.NewSession()
197 if err != nil {
198 return err
199 }
200 defer session.Close()
201
202 // Use sudo tee to write content to file
203 cmd := fmt.Sprintf("sudo tee %s > /dev/null", remotePath)
204 session.Stdin = strings.NewReader(content)
205
206 var stderr bytes.Buffer
207 session.Stderr = &stderr
208
209 if err := session.Run(cmd); err != nil {
210 return fmt.Errorf("write file with sudo failed: %w\nstderr: %s", err, stderr.String())
211 }
212
213 return nil
214}
215
216// readSSHConfig reads and parses the SSH config file for a given host
217func readSSHConfig(host string) (*sshConfig, error) {
218 home, err := os.UserHomeDir()
219 if err != nil {
220 return nil, err
221 }
222
223 configPath := filepath.Join(home, ".ssh", "config")
224 file, err := os.Open(configPath)
225 if err != nil {
226 return nil, err
227 }
228 defer file.Close()
229
230 cfg := &sshConfig{}
231 var currentHost string
232 var matchedHost bool
233
234 scanner := bufio.NewScanner(file)
235 for scanner.Scan() {
236 line := strings.TrimSpace(scanner.Text())
237
238 // Skip comments and empty lines
239 if line == "" || strings.HasPrefix(line, "#") {
240 continue
241 }
242
243 fields := strings.Fields(line)
244 if len(fields) < 2 {
245 continue
246 }
247
248 key := strings.ToLower(fields[0])
249 value := fields[1]
250
251 // Expand ~ in paths
252 if strings.HasPrefix(value, "~/") {
253 value = filepath.Join(home, value[2:])
254 }
255
256 switch key {
257 case "host":
258 currentHost = value
259 if currentHost == host {
260 matchedHost = true
261 cfg.Host = host
262 } else {
263 matchedHost = false
264 }
265 case "hostname":
266 if matchedHost {
267 cfg.HostName = value
268 }
269 case "user":
270 if matchedHost {
271 cfg.User = value
272 }
273 case "port":
274 if matchedHost {
275 cfg.Port = value
276 }
277 case "identityfile":
278 if matchedHost {
279 cfg.IdentityFile = value
280 }
281 }
282 }
283
284 if err := scanner.Err(); err != nil {
285 return nil, err
286 }
287
288 if cfg.Host == "" {
289 return nil, fmt.Errorf("host %s not found in SSH config", host)
290 }
291
292 return cfg, nil
293}
294
295// sshAgent returns an auth method using SSH agent
296func sshAgent() (ssh.AuthMethod, error) {
297 socket := os.Getenv("SSH_AUTH_SOCK")
298 if socket == "" {
299 return nil, fmt.Errorf("SSH_AUTH_SOCK not set")
300 }
301
302 conn, err := net.Dial("unix", socket)
303 if err != nil {
304 return nil, fmt.Errorf("failed to connect to SSH agent: %w", err)
305 }
306
307 agentClient := agent.NewClient(conn)
308 return ssh.PublicKeysCallback(agentClient.Signers), nil
309}
310
311// publicKeyFromFile returns an auth method from a specific private key file
312func publicKeyFromFile(keyPath string) (ssh.AuthMethod, error) {
313 key, err := os.ReadFile(keyPath)
314 if err != nil {
315 return nil, err
316 }
317
318 signer, err := ssh.ParsePrivateKey(key)
319 if err != nil {
320 return nil, err
321 }
322
323 return ssh.PublicKeys(signer), nil
324}
325
326// publicKeyFile returns an auth method using a private key file
327func publicKeyFile() (ssh.AuthMethod, error) {
328 home, err := os.UserHomeDir()
329 if err != nil {
330 return nil, err
331 }
332
333 // Try common key locations
334 keyPaths := []string{
335 filepath.Join(home, ".ssh", "id_rsa"),
336 filepath.Join(home, ".ssh", "id_ed25519"),
337 filepath.Join(home, ".ssh", "id_ecdsa"),
338 }
339
340 for _, keyPath := range keyPaths {
341 if _, err := os.Stat(keyPath); err == nil {
342 key, err := os.ReadFile(keyPath)
343 if err != nil {
344 continue
345 }
346
347 signer, err := ssh.ParsePrivateKey(key)
348 if err != nil {
349 continue
350 }
351
352 return ssh.PublicKeys(signer), nil
353 }
354 }
355
356 return nil, fmt.Errorf("no SSH private key found")
357}