Errors and other stuff

This commit is contained in:
0880
2026-01-21 01:55:33 +03:30
parent f91943cde2
commit fe65fafbe0
2 changed files with 80 additions and 20 deletions

View File

@@ -6,6 +6,7 @@ from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Any, Awaitable, Callable, Optional
from . import _errors
from .responses import HTTPResponse, JSONResponse, Response
PR = re.compile(r"\<([a-zA-Z_][a-zA-Z0-9_]*)\>")
@@ -64,23 +65,25 @@ class Request:
async def _default_404_route(request: Request):
return "HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n404 Not Found contact admin".encode(
encoding="utf-8"
)
return HTTPResponse("404 Not Found", status=404)
async def _default_405_route(request: Request):
return "HTTP/1.1 405 Method Not Allowed\r\nContent-Type: text/html\r\n\r\n405 Method Not Allowed".encode(
encoding="utf-8"
)
return HTTPResponse(b"", status=405)
class App:
def __init__(self):
self.routes: dict[
re.Pattern[str], dict[str, Callable[[Request, ...], Awaitable[Response]]]
re.Pattern[str],
dict[
str,
Callable[[Request, ...], Awaitable[Response | None]],
],
] = {}
self.error_routes: dict[int, Callable[[Request, ...], Awaitable[Response]]] = {
self.error_routes: dict[
int, Callable[[Request, ...], Awaitable[Response | None]]
] = {
404: _default_404_route,
405: _default_405_route,
}
@@ -98,10 +101,11 @@ class App:
self,
path: str,
method: str,
func: Callable[[Request, ...], Awaitable[Response]],
func: Callable[[Request, ...], Awaitable[Response | None]],
):
if method not in ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"]:
raise RuntimeError(f'Invalid method "{method}".')
_errors.error_not_async(func)
pat = self._pattern_to_regex(path)
if pat not in self.routes:
self.routes[pat] = {}
@@ -112,12 +116,13 @@ class App:
def GET(self, path: str):
"""Decorator to register a GET HTTP route."""
def decorator(func: Callable[[Request, ...], Awaitable[Response]]):
def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]):
self._serve(path, "GET", func)
async def wrapper(*args, **kwargs):
res: Response = await func(*args, **kwargs)
res.no_header = True
res = await func(*args, **kwargs)
if isinstance(res, Response):
res.no_header = True
return res
self._serve(path, "HEAD", wrapper)
@@ -128,7 +133,7 @@ class App:
def POST(self, path: str):
"""Decorator to register a POST HTTP route."""
def decorator(func: Callable[[Request, ...], Awaitable[Response]]):
def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]):
self._serve(path, "POST", func)
return func
@@ -137,14 +142,14 @@ class App:
def PUT(self, path: str):
"""Decorator to register a PUT HTTP route."""
def decorator(func: Callable[[Request, ...], Awaitable[Response]]):
def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]):
self._serve(path, "PUT", func)
return func
def DELETE(self, path: str):
"""Decorator to register a DELETE HTTP route."""
def decorator(func: Callable[[Request, ...], Awaitable[Response]]):
def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]):
self._serve(path, "DELETE", func)
return func
@@ -153,7 +158,7 @@ class App:
def PATCH(self, path: str):
"""Decorator to register a PATCH HTTP route."""
def decorator(func: Callable[[Request, ...], Awaitable[Response]]):
def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]):
self._serve(path, "PATCH", func)
return func
@@ -162,7 +167,7 @@ class App:
def HEAD(self, path: str):
"""Decorator to register a HEAD HTTP route."""
def decorator(func: Callable[[Request, ...], Awaitable[Response]]):
def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]):
self._serve(path, "HEAD", func)
return func
@@ -180,17 +185,17 @@ class App:
def resolve(self, path, method) -> tuple[Callable[..., Awaitable[Response]], dict]:
for pattern, route in self.routes.items():
if m := pattern.fullmatch(path):
if method not in route:
if method not in route or method == "websocket":
return self.error_routes[405], {}
return route[method], {
k: urllib.parse.unquote(v) for k, v in m.groupdict().items()
}
} # ty:ignore[invalid-return-type]
return self.error_routes[404], {}
def methods(self, path) -> set[str]:
for pattern, route in self.routes.items():
if pattern.fullmatch(path):
return set(route.keys())
return set(route.keys() - {"websocket"})
return set()
async def handle_client(
@@ -203,6 +208,12 @@ class App:
if not request_line:
return
if not request_line.decode(encoding="utf-8").startswith(
("GET", "POST", "PUT", "HEAD", "DELETE", "PATCH", "OPTIONS")
):
# Probably WebSocket
pass
# Parse request line
parts = request_line.decode(encoding="utf-8").strip().split()
if len(parts) < 3:
@@ -251,6 +262,17 @@ class App:
else:
route, kwargs = self.resolve(path, method)
if (
method == "GET"
and "connection" in headers
and headers.get("connection") == "Upgrade"
and "upgrade" in headers
):
# Upgrade
return Response(
status=426, headers=["Connection: close"], body=b""
).render(self)
response: Response = await route(
request=Request(
method=method, path=path, headers=headers, body=body