server: Parse port numbers from MCP server URLs in CORS proxy (#20208)
* Parse port numbers from MCP server URLs * Pass scheme to http proxy for determining whether to use SSL * Fix download on non-standard port and re-add port to logging * add test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
e22cd0aa15
commit
23fbfcb1ad
6 changed files with 67 additions and 4 deletions
|
|
@ -7,6 +7,7 @@ struct common_http_url {
|
||||||
std::string user;
|
std::string user;
|
||||||
std::string password;
|
std::string password;
|
||||||
std::string host;
|
std::string host;
|
||||||
|
int port;
|
||||||
std::string path;
|
std::string path;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -47,6 +48,20 @@ static common_http_url common_http_parse_url(const std::string & url) {
|
||||||
parts.host = rest;
|
parts.host = rest;
|
||||||
parts.path = "/";
|
parts.path = "/";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto colon_pos = parts.host.find(':');
|
||||||
|
|
||||||
|
if (colon_pos != std::string::npos) {
|
||||||
|
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
|
||||||
|
parts.host = parts.host.substr(0, colon_pos);
|
||||||
|
} else if (parts.scheme == "http") {
|
||||||
|
parts.port = 80;
|
||||||
|
} else if (parts.scheme == "https") {
|
||||||
|
parts.port = 443;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
|
||||||
|
}
|
||||||
|
|
||||||
return parts;
|
return parts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -68,7 +83,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
|
||||||
|
|
||||||
if (!parts.user.empty()) {
|
if (!parts.user.empty()) {
|
||||||
cli.set_basic_auth(parts.user, parts.password);
|
cli.set_basic_auth(parts.user, parts.password);
|
||||||
|
|
|
||||||
|
|
@ -30,12 +30,13 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
||||||
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
|
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
|
||||||
}
|
}
|
||||||
|
|
||||||
SRV_INF("proxying %s request to %s://%s%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.path.c_str());
|
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
||||||
|
|
||||||
auto proxy = std::make_unique<server_http_proxy>(
|
auto proxy = std::make_unique<server_http_proxy>(
|
||||||
method,
|
method,
|
||||||
|
parsed_url.scheme,
|
||||||
parsed_url.host,
|
parsed_url.host,
|
||||||
parsed_url.scheme == "http" ? 80 : 443,
|
parsed_url.port,
|
||||||
parsed_url.path,
|
parsed_url.path,
|
||||||
req.headers,
|
req.headers,
|
||||||
req.body,
|
req.body,
|
||||||
|
|
|
||||||
|
|
@ -783,6 +783,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
|
||||||
}
|
}
|
||||||
auto proxy = std::make_unique<server_http_proxy>(
|
auto proxy = std::make_unique<server_http_proxy>(
|
||||||
method,
|
method,
|
||||||
|
"http",
|
||||||
CHILD_ADDR,
|
CHILD_ADDR,
|
||||||
meta->port,
|
meta->port,
|
||||||
proxy_path,
|
proxy_path,
|
||||||
|
|
@ -1079,6 +1080,7 @@ static bool should_strip_proxy_header(const std::string & header_name) {
|
||||||
|
|
||||||
server_http_proxy::server_http_proxy(
|
server_http_proxy::server_http_proxy(
|
||||||
const std::string & method,
|
const std::string & method,
|
||||||
|
const std::string & scheme,
|
||||||
const std::string & host,
|
const std::string & host,
|
||||||
int port,
|
int port,
|
||||||
const std::string & path,
|
const std::string & path,
|
||||||
|
|
@ -1092,7 +1094,7 @@ server_http_proxy::server_http_proxy(
|
||||||
auto cli = std::make_shared<httplib::ClientImpl>(host, port);
|
auto cli = std::make_shared<httplib::ClientImpl>(host, port);
|
||||||
auto pipe = std::make_shared<pipe_t<msg_t>>();
|
auto pipe = std::make_shared<pipe_t<msg_t>>();
|
||||||
|
|
||||||
if (port == 443) {
|
if (scheme == "https") {
|
||||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||||
cli.reset(new httplib::SSLClient(host, port));
|
cli.reset(new httplib::SSLClient(host, port));
|
||||||
#else
|
#else
|
||||||
|
|
|
||||||
|
|
@ -180,6 +180,7 @@ struct server_http_proxy : server_http_res {
|
||||||
std::function<void()> cleanup = nullptr;
|
std::function<void()> cleanup = nullptr;
|
||||||
public:
|
public:
|
||||||
server_http_proxy(const std::string & method,
|
server_http_proxy(const std::string & method,
|
||||||
|
const std::string & scheme,
|
||||||
const std::string & host,
|
const std::string & host,
|
||||||
int port,
|
int port,
|
||||||
const std::string & path,
|
const std::string & path,
|
||||||
|
|
|
||||||
41
tools/server/tests/unit/test_proxy.py
Normal file
41
tools/server/tests/unit/test_proxy.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
import pytest
|
||||||
|
from utils import *
|
||||||
|
|
||||||
|
server = ServerPreset.tinyllama2()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def create_server():
|
||||||
|
global server
|
||||||
|
server = ServerPreset.tinyllama2()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_no_proxy():
|
||||||
|
global server
|
||||||
|
server.webui_mcp_proxy = False
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("GET", "/cors-proxy")
|
||||||
|
assert res.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_proxy():
|
||||||
|
global server
|
||||||
|
server.webui_mcp_proxy = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com"
|
||||||
|
res = requests.get(url)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "Example Domain" in res.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_proxy_custom_port():
|
||||||
|
global server
|
||||||
|
server.webui_mcp_proxy = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
# try getting the server's models API via the proxy
|
||||||
|
res = server.make_request("GET", f"/cors-proxy?url=http://{server.server_host}:{server.server_port}/models")
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "data" in res.body
|
||||||
|
|
@ -102,6 +102,7 @@ class ServerProcess:
|
||||||
mmproj_url: str | None = None
|
mmproj_url: str | None = None
|
||||||
media_path: str | None = None
|
media_path: str | None = None
|
||||||
sleep_idle_seconds: int | None = None
|
sleep_idle_seconds: int | None = None
|
||||||
|
webui_mcp_proxy: bool = False
|
||||||
|
|
||||||
# session variables
|
# session variables
|
||||||
process: subprocess.Popen | None = None
|
process: subprocess.Popen | None = None
|
||||||
|
|
@ -236,6 +237,8 @@ class ServerProcess:
|
||||||
server_args.extend(["--media-path", self.media_path])
|
server_args.extend(["--media-path", self.media_path])
|
||||||
if self.sleep_idle_seconds is not None:
|
if self.sleep_idle_seconds is not None:
|
||||||
server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
|
server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds])
|
||||||
|
if self.webui_mcp_proxy:
|
||||||
|
server_args.append("--webui-mcp-proxy")
|
||||||
|
|
||||||
args = [str(arg) for arg in [server_path, *server_args]]
|
args = [str(arg) for arg in [server_path, *server_args]]
|
||||||
print(f"tests: starting server with: {' '.join(args)}")
|
print(f"tests: starting server with: {' '.join(args)}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue