diff --git a/lib/vquic/curl_osslq.c b/lib/vquic/curl_osslq.c index a69da56803..c221124d72 100644 --- a/lib/vquic/curl_osslq.c +++ b/lib/vquic/curl_osslq.c @@ -290,7 +290,7 @@ struct cf_osslq_ctx { uint64_t max_idle_ms; /* max idle time for QUIC connection */ SSL_POLL_ITEM *poll_items; /* Array for polling on writable state */ struct Curl_easy **curl_items; /* Array of easy objs */ - size_t item_count; /* count of elements in poll/curl_items */ + size_t items_max; /* max elements in poll/curl_items */ BIT(initialized); BIT(got_first_byte); /* if first byte was received */ BIT(x509_store_setup); /* if x509 store has been set up */ @@ -309,7 +309,7 @@ static void cf_osslq_ctx_init(struct cf_osslq_ctx *ctx) Curl_hash_offt_init(&ctx->streams, 63, h3_stream_hash_free); ctx->poll_items = NULL; ctx->curl_items = NULL; - ctx->item_count = 0; + ctx->items_max = 0; ctx->initialized = TRUE; } @@ -666,6 +666,24 @@ static void h3_data_done(struct Curl_cfilter *cf, struct Curl_easy *data) } } +struct cf_ossq_find_ctx { + curl_int64_t stream_id; + struct h3_stream_ctx *stream; +}; + +static bool cf_osslq_find_stream(curl_off_t mid, void *val, void *user_data) +{ + struct h3_stream_ctx *stream = val; + struct cf_ossq_find_ctx *fctx = user_data; + + (void)mid; + if(stream && stream->s.id == fctx->stream_id) { + fctx->stream = stream; + return FALSE; /* stop iterating */ + } + return TRUE; +} + static struct cf_osslq_stream *cf_osslq_get_qstream(struct Curl_cfilter *cf, struct Curl_easy *data, int64_t stream_id) @@ -686,17 +704,12 @@ static struct cf_osslq_stream *cf_osslq_get_qstream(struct Curl_cfilter *cf, return &ctx->h3.s_qpack_dec; } else { - struct Curl_llist_node *e; - DEBUGASSERT(data->multi); - for(e = Curl_llist_head(&data->multi->process); e; e = Curl_node_next(e)) { - struct Curl_easy *sdata = Curl_node_elem(e); - if(sdata->conn != data->conn) - continue; - stream = H3_STREAM_CTX(ctx, sdata); - if(stream && stream->s.id == stream_id) { - return &stream->s; - } - } + struct cf_ossq_find_ctx fctx; + fctx.stream_id = stream_id; + fctx.stream = NULL; + Curl_hash_offt_visit(&ctx->streams, cf_osslq_find_stream, &fctx); + if(fctx.stream) + return &fctx.stream->s; } return NULL; } @@ -1401,6 +1414,29 @@ out: return result; } +struct cf_ossq_recv_ctx { + struct Curl_cfilter *cf; + struct Curl_multi *multi; + CURLcode result; +}; + +static bool cf_osslq_iter_recv(curl_off_t mid, void *val, void *user_data) +{ + struct h3_stream_ctx *stream = val; + struct cf_ossq_recv_ctx *rctx = user_data; + + (void)mid; + if(stream && !stream->closed && !Curl_bufq_is_full(&stream->recvbuf)) { + struct Curl_easy *sdata = Curl_multi_get_handle(rctx->multi, mid); + if(sdata) { + rctx->result = cf_osslq_stream_recv(&stream->s, rctx->cf, sdata); + if(rctx->result) + return FALSE; /* abort iteration */ + } + } + return TRUE; +} + static CURLcode cf_progress_ingress(struct Curl_cfilter *cf, struct Curl_easy *data) { @@ -1437,22 +1473,14 @@ static CURLcode cf_progress_ingress(struct Curl_cfilter *cf, } if(ctx->h3.conn) { - struct Curl_llist_node *e; - struct h3_stream_ctx *stream; - /* PULL all open streams */ + struct cf_ossq_recv_ctx rctx; + DEBUGASSERT(data->multi); - for(e = Curl_llist_head(&data->multi->process); e; e = Curl_node_next(e)) { - struct Curl_easy *sdata = Curl_node_elem(e); - if(sdata->conn == data->conn && CURL_WANT_RECV(sdata)) { - stream = H3_STREAM_CTX(ctx, sdata); - if(stream && !stream->closed && - !Curl_bufq_is_full(&stream->recvbuf)) { - result = cf_osslq_stream_recv(&stream->s, cf, sdata); - if(result) - goto out; - } - } - } + rctx.cf = cf; + rctx.multi = data->multi; + rctx.result = CURLE_OK; + Curl_hash_offt_visit(&ctx->streams, cf_osslq_iter_recv, &rctx); + result = rctx.result; } out: @@ -1460,13 +1488,43 @@ out: return result; } +struct cf_ossq_fill_ctx { + struct cf_osslq_ctx *ctx; + struct Curl_multi *multi; + size_t n; +}; + +static bool cf_osslq_collect_block_send(curl_off_t mid, void *val, + void *user_data) +{ + struct h3_stream_ctx *stream = val; + struct cf_ossq_fill_ctx *fctx = user_data; + struct cf_osslq_ctx *ctx = fctx->ctx; + + if(fctx->n >= ctx->items_max) /* should not happen, prevent mayhem */ + return FALSE; + + if(stream && stream->s.ssl && stream->s.send_blocked) { + struct Curl_easy *sdata = Curl_multi_get_handle(fctx->multi, mid); + fprintf(stderr, "[OSSLQ] stream %" FMT_PRId64 " sdata=%p\n", + stream->s.id, (void *)sdata); + if(sdata) { + ctx->poll_items[fctx->n].desc = SSL_as_poll_descriptor(stream->s.ssl); + ctx->poll_items[fctx->n].events = SSL_POLL_EVENT_W; + ctx->curl_items[fctx->n] = sdata; + fctx->n++; + } + } + return TRUE; +} + /* Iterate over all streams and check if blocked can be unblocked */ static CURLcode cf_osslq_check_and_unblock(struct Curl_cfilter *cf, struct Curl_easy *data) { struct cf_osslq_ctx *ctx = cf->ctx; struct h3_stream_ctx *stream; - size_t poll_count = 0; + size_t poll_count; size_t result_count = 0; size_t idx_count = 0; CURLcode res = CURLE_OK; @@ -1474,68 +1532,60 @@ static CURLcode cf_osslq_check_and_unblock(struct Curl_cfilter *cf, void *tmpptr; if(ctx->h3.conn) { - struct Curl_llist_node *e; + struct cf_ossq_fill_ctx fill_ctx; - res = CURLE_OUT_OF_MEMORY; - - if(ctx->item_count < Curl_llist_count(&data->multi->process)) { - ctx->item_count = 0; - tmpptr = realloc(ctx->poll_items, - Curl_llist_count(&data->multi->process) * - sizeof(SSL_POLL_ITEM)); + if(ctx->items_max < Curl_hash_offt_count(&ctx->streams)) { + size_t nmax = Curl_hash_offt_count(&ctx->streams); + ctx->items_max = 0; + tmpptr = realloc(ctx->poll_items, nmax * sizeof(SSL_POLL_ITEM)); if(!tmpptr) { free(ctx->poll_items); ctx->poll_items = NULL; + res = CURLE_OUT_OF_MEMORY; goto out; } ctx->poll_items = tmpptr; - tmpptr = realloc(ctx->curl_items, - Curl_llist_count(&data->multi->process) * - sizeof(struct Curl_easy *)); + tmpptr = realloc(ctx->curl_items, nmax * sizeof(struct Curl_easy *)); if(!tmpptr) { free(ctx->curl_items); ctx->curl_items = NULL; + res = CURLE_OUT_OF_MEMORY; goto out; } ctx->curl_items = tmpptr; - - ctx->item_count = Curl_llist_count(&data->multi->process); + ctx->items_max = nmax; } - for(e = Curl_llist_head(&data->multi->process); e; e = Curl_node_next(e)) { - struct Curl_easy *sdata = Curl_node_elem(e); - if(sdata->conn == data->conn) { - stream = H3_STREAM_CTX(ctx, sdata); - if(stream && stream->s.ssl && stream->s.send_blocked) { - ctx->poll_items[poll_count].desc = - SSL_as_poll_descriptor(stream->s.ssl); - ctx->poll_items[poll_count].events = SSL_POLL_EVENT_W; - ctx->curl_items[poll_count] = sdata; - poll_count++; + fill_ctx.ctx = ctx; + fill_ctx.multi = data->multi; + fill_ctx.n = 0; + Curl_hash_offt_visit(&ctx->streams, cf_osslq_collect_block_send, + &fill_ctx); + poll_count = fill_ctx.n; + if(poll_count) { + CURL_TRC_CF(data, cf, "polling %zu blocked streams", poll_count); + + memset(&timeout, 0, sizeof(struct timeval)); + res = CURLE_UNRECOVERABLE_POLL; + if(!SSL_poll(ctx->poll_items, poll_count, sizeof(SSL_POLL_ITEM), + &timeout, 0, &result_count)) + goto out; + + res = CURLE_OK; + + for(idx_count = 0; idx_count < poll_count && result_count > 0; + idx_count++) { + if(ctx->poll_items[idx_count].revents & SSL_POLL_EVENT_W) { + stream = H3_STREAM_CTX(ctx, ctx->curl_items[idx_count]); + nghttp3_conn_unblock_stream(ctx->h3.conn, stream->s.id); + stream->s.send_blocked = FALSE; + h3_drain_stream(cf, ctx->curl_items[idx_count]); + CURL_TRC_CF(ctx->curl_items[idx_count], cf, "unblocked"); + result_count--; } } } - - memset(&timeout, 0, sizeof(struct timeval)); - res = CURLE_UNRECOVERABLE_POLL; - if(!SSL_poll(ctx->poll_items, poll_count, sizeof(SSL_POLL_ITEM), &timeout, - 0, &result_count)) - goto out; - - res = CURLE_OK; - - for(idx_count = 0; idx_count < poll_count && result_count > 0; - idx_count++) { - if(ctx->poll_items[idx_count].revents & SSL_POLL_EVENT_W) { - stream = H3_STREAM_CTX(ctx, ctx->curl_items[idx_count]); - nghttp3_conn_unblock_stream(ctx->h3.conn, stream->s.id); - stream->s.send_blocked = FALSE; - h3_drain_stream(cf, ctx->curl_items[idx_count]); - CURL_TRC_CF(ctx->curl_items[idx_count], cf, "unblocked"); - result_count--; - } - } } out: diff --git a/tests/http/test_14_auth.py b/tests/http/test_14_auth.py index 237d7ecda8..13193b53b8 100644 --- a/tests/http/test_14_auth.py +++ b/tests/http/test_14_auth.py @@ -134,4 +134,4 @@ class TestAuth: # Depending on protocol, we might have an error sending or # the server might shutdown the connection and we see the error # on receiving - assert r.exit_code in [55, 56], f'{r.dump_logs()}' + assert r.exit_code in [55, 56, 95], f'{r.dump_logs()}'