summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sys/src/cmd/ssh.c333
1 files changed, 175 insertions, 158 deletions
diff --git a/sys/src/cmd/ssh.c b/sys/src/cmd/ssh.c
index 34b53099e..b85e6ec48 100644
--- a/sys/src/cmd/ssh.c
+++ b/sys/src/cmd/ssh.c
@@ -44,6 +44,7 @@ enum {
typedef struct
{
+ int pid;
u32int seq;
u32int kex;
Chachastate cs1;
@@ -59,19 +60,18 @@ typedef struct
int nsid;
uchar sid[256];
-int fd, pid1, pid2, intr, raw, debug;
+int fd, intr, raw, debug;
char *user, *status, *host, *cmd;
Oneway recv, send;
+void dispatch(void);
void
shutdown(void)
{
- int pid = getpid();
- if(pid1 && pid1 != pid)
- postnote(PNPROC, pid1, "shutdown");
- if(pid2 && pid2 != pid)
- postnote(PNPROC, pid2, "shutdown");
+ recv.eof = send.eof = 1;
+ if(send.pid > 0)
+ postnote(PNPROC, send.pid, "shutdown");
}
void
@@ -353,35 +353,6 @@ if(debug > 1)
return recv.r[0];
}
-void
-unexpected(char *info)
-{
- char *s;
- int n, c;
-
- switch(recv.r[0]){
- case MSG_DISCONNECT:
- if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
- break;
- sysfatal("disconnect: (%d) %.*s", c, n, s);
- break;
- case MSG_IGNORE:
- case MSG_GLOBAL_REQUEST:
- return;
- case MSG_DEBUG:
- if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
- break;
- if(c != 0) fprint(2, "%s: %.*s\n", argv0, n, s);
- return;
- case MSG_USERAUTH_BANNER:
- if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
- break;
- if(raw) write(2, s, n);
- return;
- }
- sysfatal("%s got: %.*H", info, (int)(recv.w - recv.r), recv.r);
-}
-
static char sshrsa[] = "ssh-rsa";
int
@@ -538,7 +509,7 @@ kex(int gotkexinit)
if(!gotkexinit){
Next0: switch(recvpkt()){
default:
- unexpected("KEXINIT");
+ dispatch();
goto Next0;
case MSG_KEXINIT:
break;
@@ -570,8 +541,10 @@ kex(int gotkexinit)
sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
Next1: switch(recvpkt()){
default:
- unexpected("ECDH_INIT");
+ dispatch();
goto Next1;
+ case MSG_KEXINIT:
+ sysfatal("inception");
case MSG_ECDH_REPLY:
if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
sysfatal("bad ECDH_REPLY");
@@ -607,8 +580,10 @@ Next1: switch(recvpkt()){
sendpkt("b", MSG_NEWKEYS);
Next2: switch(recvpkt()){
default:
- unexpected("NEWKEYS");
+ dispatch();
goto Next2;
+ case MSG_KEXINIT:
+ sysfatal("inception");
case MSG_NEWKEYS:
break;
}
@@ -647,7 +622,7 @@ auth(char *username, char *servicename)
sendpkt("bs", MSG_SERVICE_REQUEST, sshuserauth, sizeof(sshuserauth)-1);
Next0: switch(recvpkt()){
default:
- unexpected("SERVICE_REQUEST");
+ dispatch();
goto Next0;
case MSG_SERVICE_ACCEPT:
break;
@@ -690,7 +665,7 @@ Next0: switch(recvpkt()){
pk, npk);
Next1: switch(recvpkt()){
default:
- unexpected("USERAUTH_REQUEST");
+ dispatch();
goto Next1;
case MSG_USERAUTH_FAILURE:
continue;
@@ -733,7 +708,7 @@ Next1: switch(recvpkt()){
sig, nsig);
Next2: switch(recvpkt()){
default:
- unexpected("USERAUTH_REQUEST");
+ dispatch();
goto Next2;
case MSG_USERAUTH_FAILURE:
continue;
@@ -751,6 +726,83 @@ Next2: switch(recvpkt()){
return -1;
}
+void
+dispatch(void)
+{
+ char *s;
+ uchar *p;
+ int n, b, c;
+
+ switch(recv.r[0]){
+ case MSG_IGNORE:
+ case MSG_GLOBAL_REQUEST:
+ case MSG_CHANNEL_WINDOW_ADJUST:
+ return;
+ case MSG_DISCONNECT:
+ if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
+ break;
+ sysfatal("disconnect: (%d) %.*s", c, n, s);
+ return;
+ case MSG_DEBUG:
+ if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
+ break;
+ if(c != 0 || debug) fprint(2, "%s: %.*s\n", argv0, n, s);
+ return;
+ case MSG_USERAUTH_BANNER:
+ if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
+ break;
+ if(raw) write(2, s, n);
+ return;
+ case MSG_CHANNEL_DATA:
+ if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
+ break;
+ if(c != 0)
+ break;
+ if(write(1, s, n) != n)
+ sysfatal("write out: %r");
+ Winadjust:
+ sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
+ return;
+ case MSG_CHANNEL_EXTENDED_DATA:
+ if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
+ break;
+ if(c != 0)
+ break;
+ if(b == 1) write(2, s, n);
+ goto Winadjust;
+ case MSG_CHANNEL_REQUEST:
+ if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
+ break;
+ if(c != 0)
+ break;
+ if(n == 11 && memcmp(s, "exit-signal", n) == 0){
+ if(unpack(p, recv.w-p, "s", &s, &n) < 0)
+ break;
+ if(n != 0 && status == nil)
+ status = smprint("%.*s", n, s);
+ } else if(n == 11 && memcmp(s, "exit-status", n) == 0){
+ if(unpack(p, recv.w-p, "u", &n) < 0)
+ break;
+ if(n != 0 && status == nil)
+ status = smprint("%d", n);
+ } else if(debug) {
+ fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
+ }
+ return;
+ case MSG_CHANNEL_EOF:
+ recv.eof = 1;
+ if(!raw) write(1, "", 0);
+ return;
+ case MSG_CHANNEL_CLOSE:
+ shutdown();
+ return;
+ case MSG_KEXINIT:
+ kex(1);
+ return;
+ }
+ sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
+}
+
char*
readline(void)
{
@@ -830,7 +882,6 @@ main(int argc, char *argv[])
static QLock sl;
int b, n, c;
char *s;
- uchar *p;
quotefmtinstall();
fmtinstall('B', mpfmt);
@@ -889,7 +940,6 @@ main(int argc, char *argv[])
recv.v = strdup(recv.v);
kex(0);
-
if(user == nil)
user = getuser();
if(auth(user, "ssh-connection") < 0)
@@ -902,125 +952,92 @@ main(int argc, char *argv[])
sizeof(buf),
sizeof(buf));
- while((send.eof | recv.eof) == 0){
- if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0){
- qlock(&sl);
- kex(0);
+Next0: switch(recvpkt()){
+ default:
+ dispatch();
+ goto Next0;
+ case MSG_CHANNEL_OPEN_FAILURE:
+ if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
+ n = strlen(s = "???");
+ sysfatal("channel open failure: (%d) %.*s", b, n, s);
+ case MSG_CHANNEL_OPEN_CONFIRMATION:
+ break;
+ }
+
+ notify(catch);
+ atexit(shutdown);
+
+ recv.pid = getpid();
+ n = rfork(RFPROC|RFMEM);
+ if(n < 0)
+ sysfatal("fork: %r");
+
+ /* parent reads and dispatches packets */
+ if(n > 0) {
+ send.pid = n;
+ while((send.eof|recv.eof) == 0){
+ recvpkt();
+ qlock(&sl);
+ dispatch();
+ if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
+ kex(0);
qunlock(&sl);
}
- switch(recvpkt()){
- default:
- unexpected("CHANNEL");
- continue;
- case MSG_KEXINIT:
- qlock(&sl);
- kex(1);
- qunlock(&sl);
- continue;
- case MSG_CHANNEL_WINDOW_ADJUST:
- continue;
- case MSG_CHANNEL_EXTENDED_DATA:
- if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
- unexpected("CHANNEL_EXTENDED_DATA");
- if(b == 1) write(2, s, n);
- sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
- continue;
- case MSG_CHANNEL_DATA:
- if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
- unexpected("CHANNEL_DATA");
- write(1, s, n);
- sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
- continue;
- case MSG_CHANNEL_EOF:
- recv.eof = 1;
- if(!raw) write(1, "", 0);
- continue;
- case MSG_CHANNEL_OPEN_FAILURE:
- if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
- unexpected("CHANNEL_OPEN_FAILURE");
- sysfatal("channel open failure: (%d) %.*s", b, n, s);
- break;
- case MSG_CHANNEL_OPEN_CONFIRMATION:
- if(raw) {
- rawon();
- sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
- 0,
- "pty-req", 7,
- 0,
- tty.term, strlen(tty.term),
- tty.cols,
- tty.lines,
- tty.xpixels,
- tty.ypixels,
- "", 0);
- }
- if(cmd == nil){
- sendpkt("busb", MSG_CHANNEL_REQUEST,
- 0,
- "shell", 5,
- 0);
- } else {
- sendpkt("busbs", MSG_CHANNEL_REQUEST,
- 0,
- "exec", 4,
- 0,
- cmd, strlen(cmd));
- }
- if(pid2)
- continue;
- pid1 = getpid();
- notify(catch);
- atexit(shutdown);
- n = rfork(RFPROC|RFMEM);
- if(n){
- pid2 = n;
- continue;
- }
- qlock(&sl);
- for(;;){
- qunlock(&sl);
- n = read(0, buf, sizeof(buf));
- qlock(&sl);
- if(n < 0 && wasintr()){
- sendpkt("busbs", MSG_CHANNEL_REQUEST,
- 0,
- "signal", 6,
- 0,
- "INT", 3);
- intr = 0;
- continue;
- }
- if(n <= 0)
- break;
- sendpkt("bus", MSG_CHANNEL_DATA,
- 0,
- buf, n);
- }
- send.eof = 1;
- sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0);
- qunlock(&sl);
+ exits(status);
+ }
+
+ /* child reads input and sends packets */
+ qlock(&sl);
+ if(raw) {
+ rawon();
+ sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
+ 0,
+ "pty-req", 7,
+ 0,
+ tty.term, strlen(tty.term),
+ tty.cols,
+ tty.lines,
+ tty.xpixels,
+ tty.ypixels,
+ "", 0);
+ }
+ if(cmd == nil){
+ sendpkt("busb", MSG_CHANNEL_REQUEST,
+ 0,
+ "shell", 5,
+ 0);
+ } else {
+ sendpkt("busbs", MSG_CHANNEL_REQUEST,
+ 0,
+ "exec", 4,
+ 0,
+ cmd, strlen(cmd));
+ }
+ for(;;){
+ qunlock(&sl);
+ n = read(0, buf, sizeof(buf));
+ qlock(&sl);
+ if(send.eof)
break;
- case MSG_CHANNEL_REQUEST:
- if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
- unexpected("CHANNEL_REQUEST");
- if(n == 11 && memcmp(s, "exit-signal", n) == 0){
- if(unpack(p, recv.w-p, "s", &s, &n) < 0)
- continue;
- if(n != 0 && status == nil)
- status = smprint("%.*s", n, s);
- } else if(n == 11 && memcmp(s, "exit-status", n) == 0){
- if(unpack(p, recv.w-p, "u", &n) < 0)
- continue;
- if(n != 0 && status == nil)
- status = smprint("%d", n);
- } else {
- fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
- }
+ if(n < 0 && wasintr()){
+ if(!raw) break;
+ sendpkt("busbs", MSG_CHANNEL_REQUEST,
+ 0,
+ "signal", 6,
+ 0,
+ "INT", 3);
+ intr = 0;
continue;
- case MSG_CHANNEL_CLOSE:
- break;
}
- break;
+ if(n <= 0)
+ break;
+ sendpkt("bus", MSG_CHANNEL_DATA,
+ 0,
+ buf, n);
}
- exits(status);
+ if(send.eof++ == 0)
+ sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0);
+ qunlock(&sl);
+
+ exits(nil);
}