diff --git a/net/lwip/wget.c b/net/lwip/wget.c index 4add520045f..53c3b169e01 100644 --- a/net/lwip/wget.c +++ b/net/lwip/wget.c @@ -34,6 +34,18 @@ struct wget_ctx { enum done_state done; }; +static void wget_lwip_fill_info(struct pbuf *hdr, u16_t hdr_len, u32_t hdr_cont_len) +{ + if (wget_info->headers && hdr_len < MAX_HTTP_HEADERS_SIZE) + pbuf_copy_partial(hdr, (void *)wget_info->headers, hdr_len, 0); + wget_info->hdr_cont_len = (u32)hdr_cont_len; +} + +static void wget_lwip_set_file_size(u32_t rx_content_len) +{ + wget_info->file_size = (ulong)rx_content_len; +} + static int parse_url(char *url, char *host, u16 *port, char **path) { char *p, *pp; @@ -178,6 +190,13 @@ static void httpc_result_cb(void *arg, httpc_result_t httpc_result, struct wget_ctx *ctx = arg; ulong elapsed; + wget_info->status_code = (u32)srv_res; + + if (err == ERR_BUF) { + ctx->done = FAILURE; + return; + } + if (httpc_result != HTTPC_RESULT_OK) { log_err("\nHTTP client error %d\n", httpc_result); ctx->done = FAILURE; @@ -197,8 +216,11 @@ static void httpc_result_cb(void *arg, httpc_result_t httpc_result, printf("%u bytes transferred in %lu ms (", rx_content_len, elapsed); print_size(rx_content_len / elapsed * 1000, "/s)\n"); printf("Bytes transferred = %lu (%lx hex)\n", ctx->size, ctx->size); - efi_set_bootdev("Net", "", ctx->path, map_sysmem(ctx->saved_daddr, 0), - rx_content_len); + if (wget_info->set_bootdev) { + efi_set_bootdev("Net", "", ctx->path, map_sysmem(ctx->saved_daddr, 0), + rx_content_len); + } + wget_lwip_set_file_size(rx_content_len); if (env_set_hex("filesize", rx_content_len) || env_set_hex("fileaddr", ctx->saved_daddr)) { log_err("Could not set filesize or fileaddr\n"); @@ -209,6 +231,17 @@ static void httpc_result_cb(void *arg, httpc_result_t httpc_result, ctx->done = SUCCESS; } +static err_t httpc_headers_done_cb(httpc_state_t *connection, void *arg, struct pbuf *hdr, + u16_t hdr_len, u32_t content_len) +{ + wget_lwip_fill_info(hdr, hdr_len, content_len); + + if (wget_info->check_buffer_size && (ulong)content_len > wget_info->buffer_size) + return ERR_BUF; + + return ERR_OK; +} + static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) { httpc_connection_t conn; @@ -233,6 +266,7 @@ static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) memset(&conn, 0, sizeof(conn)); conn.result_fn = httpc_result_cb; + conn.headers_done_fn = httpc_headers_done_cb; ctx.path = path; if (httpc_get_file_dns(ctx.server_name, ctx.port, path, &conn, httpc_recv_cb, &ctx, &state)) { @@ -259,6 +293,9 @@ int wget_with_dns(ulong dst_addr, char *uri) { eth_set_current(); + if (!wget_info) + wget_info = &default_wget_info; + return wget_loop(eth_get_dev(), dst_addr, uri); } @@ -285,6 +322,7 @@ int do_wget(struct cmd_tbl *cmdtp, int flag, int argc, char * const argv[]) if (parse_legacy_arg(url, nurl, sizeof(nurl))) return CMD_RET_FAILURE; + wget_info = &default_wget_info; if (wget_with_dns(dst_addr, nurl)) return CMD_RET_FAILURE;