diff --git a/buse.c b/buse.c index f1e4a66..f1552ab 100644 --- a/buse.c +++ b/buse.c @@ -32,12 +32,42 @@ u_int64_t ntohll(u_int64_t a) { #endif #define htonll ntohll +static int read_all(int fd, void *buf, size_t count) +{ + int bytes_read; + + while (count > 0) { + bytes_read = read(fd, buf, count); + assert(bytes_read > 0); + buf = (char *)buf + bytes_read; + count -= bytes_read; + } + assert(count == 0); + + return 0; +} + +static int write_all(int fd, const void *buf, size_t count) +{ + int bytes_written; + + while (count > 0) { + bytes_written = write(fd, buf, count); + assert(bytes_written > 0); + buf = (char *)buf + bytes_written; + count -= bytes_written; + } + assert(count == 0); + + return 0; +} + int buse_main(int argc, char *argv[], const struct buse_operations *aop, void *userdata) { int sp[2]; int nbd, sk, err, tmp_fd; u_int64_t from; - u_int32_t len, bytes_read, bytes_written; + u_int32_t len, bytes_read; char *dev_file; struct nbd_request request; struct nbd_reply reply; @@ -45,8 +75,8 @@ int buse_main(int argc, char *argv[], const struct buse_operations *aop, void *u (void) userdata; - assert(argc == 2); - dev_file = argv[1]; + assert(argc == 3); + dev_file = argv[2]; assert(!socketpair(AF_UNIX, SOCK_STREAM, 0, sp)); @@ -64,6 +94,8 @@ int buse_main(int argc, char *argv[], const struct buse_operations *aop, void *u assert(ioctl(nbd, NBD_SET_SOCK, sk) != -1); err = ioctl(nbd, NBD_DO_IT); fprintf(stderr, "nbd device terminated with code %d\n", err); + if (err == -1) + fprintf(stderr, "%s\n", strerror(errno)); assert(ioctl(nbd, NBD_CLEAR_QUE) != -1); assert(ioctl(nbd, NBD_CLEAR_SOCK) != -1); @@ -105,23 +137,21 @@ int buse_main(int argc, char *argv[], const struct buse_operations *aop, void *u chunk = malloc(len + sizeof(struct nbd_reply)); aop->read((char *)chunk + sizeof(struct nbd_reply), len, from); memcpy(chunk, &reply, sizeof(struct nbd_reply)); - bytes_written = write(sk, chunk, len + sizeof(struct nbd_reply)); - assert(bytes_written == len + sizeof(struct nbd_reply)); + write_all(sk, chunk, len + sizeof(struct nbd_reply)); free(chunk); break; case NBD_CMD_WRITE: fprintf(stderr, "Request for write of size %d\n", len); chunk = malloc(len); - bytes_read = read(sk, chunk, len); - assert(bytes_read == len); + read_all(sk, chunk, len); aop->write(chunk, len, from); free(chunk); - bytes_written = write(sk, &reply, sizeof(struct nbd_reply)); - assert(bytes_written == sizeof(struct nbd_reply)); + write_all(sk, &reply, sizeof(struct nbd_reply)); break; case NBD_CMD_DISC: + /* Handle a disconnect request. */ aop->disc(); - break; + return 0; case NBD_CMD_FLUSH: aop->flush(); break;