diff --git a/slow/_errors.py b/slow/_errors.py new file mode 100644 index 0000000..55e2800 --- /dev/null +++ b/slow/_errors.py @@ -0,0 +1,38 @@ +import inspect + + +def error_not_async(func): + if not inspect.iscoroutinefunction(func): + lines, start = inspect.getsourcelines(func) + + fdef_index = 0 + while fdef_index < len(lines) and not lines[fdef_index].strip().startswith( + "def" + ): + fdef_index += 1 + + if fdef_index < len(lines): + line_num = start + fdef_index + + fdef_content = lines[fdef_index].strip() + + ERROR_HEADER = "\033[1m\033[91m[FATAL_EXECUTION_ERROR]: Non-Asynchronous Route Detected\033[0m" + + ERROR_BODY = ( + f"\n" + f'--> File "{inspect.getsourcefile(func)}", line {line_num}, in {func.__name__}\n\n' + f"\033[1m[CONSTRAINT_VIOLATION]\033[0m\n" + f" Synchronous function used where an async coroutine was required.\n" + f" \033[93mCode Traceback:\033[0m\n" + f" {line_num}: \033[93m{fdef_content}\033[0m\n\n" + f"\033[1m[SUGGESTED_PATCH]\033[0m\n" + f" Apply the 'async' keyword to the function signature:\n" + f" {line_num}: \033[92masync {fdef_content}\033[0m" + ) + + raise RuntimeError(ERROR_HEADER + ERROR_BODY) + + else: + raise RuntimeError( + "\033[1m\033[91m[FATAL_EXECUTION_ERROR]: Non-Asynchronous Route Detected\033[0m" + ) diff --git a/slow/slow.py b/slow/slow.py index abd4725..4eee55a 100644 --- a/slow/slow.py +++ b/slow/slow.py @@ -6,6 +6,7 @@ from json.decoder import JSONDecodeError from pathlib import Path from typing import Any, Awaitable, Callable, Optional +from . import _errors from .responses import HTTPResponse, JSONResponse, Response PR = re.compile(r"\<([a-zA-Z_][a-zA-Z0-9_]*)\>") @@ -64,23 +65,25 @@ class Request: async def _default_404_route(request: Request): - return "HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n404 Not Found contact admin".encode( - encoding="utf-8" - ) + return HTTPResponse("404 Not Found", status=404) async def _default_405_route(request: Request): - return "HTTP/1.1 405 Method Not Allowed\r\nContent-Type: text/html\r\n\r\n405 Method Not Allowed".encode( - encoding="utf-8" - ) + return HTTPResponse(b"", status=405) class App: def __init__(self): self.routes: dict[ - re.Pattern[str], dict[str, Callable[[Request, ...], Awaitable[Response]]] + re.Pattern[str], + dict[ + str, + Callable[[Request, ...], Awaitable[Response | None]], + ], ] = {} - self.error_routes: dict[int, Callable[[Request, ...], Awaitable[Response]]] = { + self.error_routes: dict[ + int, Callable[[Request, ...], Awaitable[Response | None]] + ] = { 404: _default_404_route, 405: _default_405_route, } @@ -98,10 +101,11 @@ class App: self, path: str, method: str, - func: Callable[[Request, ...], Awaitable[Response]], + func: Callable[[Request, ...], Awaitable[Response | None]], ): if method not in ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"]: raise RuntimeError(f'Invalid method "{method}".') + _errors.error_not_async(func) pat = self._pattern_to_regex(path) if pat not in self.routes: self.routes[pat] = {} @@ -112,12 +116,13 @@ class App: def GET(self, path: str): """Decorator to register a GET HTTP route.""" - def decorator(func: Callable[[Request, ...], Awaitable[Response]]): + def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]): self._serve(path, "GET", func) async def wrapper(*args, **kwargs): - res: Response = await func(*args, **kwargs) - res.no_header = True + res = await func(*args, **kwargs) + if isinstance(res, Response): + res.no_header = True return res self._serve(path, "HEAD", wrapper) @@ -128,7 +133,7 @@ class App: def POST(self, path: str): """Decorator to register a POST HTTP route.""" - def decorator(func: Callable[[Request, ...], Awaitable[Response]]): + def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]): self._serve(path, "POST", func) return func @@ -137,14 +142,14 @@ class App: def PUT(self, path: str): """Decorator to register a PUT HTTP route.""" - def decorator(func: Callable[[Request, ...], Awaitable[Response]]): + def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]): self._serve(path, "PUT", func) return func def DELETE(self, path: str): """Decorator to register a DELETE HTTP route.""" - def decorator(func: Callable[[Request, ...], Awaitable[Response]]): + def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]): self._serve(path, "DELETE", func) return func @@ -153,7 +158,7 @@ class App: def PATCH(self, path: str): """Decorator to register a PATCH HTTP route.""" - def decorator(func: Callable[[Request, ...], Awaitable[Response]]): + def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]): self._serve(path, "PATCH", func) return func @@ -162,7 +167,7 @@ class App: def HEAD(self, path: str): """Decorator to register a HEAD HTTP route.""" - def decorator(func: Callable[[Request, ...], Awaitable[Response]]): + def decorator(func: Callable[[Request, ...], Awaitable[Response | None]]): self._serve(path, "HEAD", func) return func @@ -180,17 +185,17 @@ class App: def resolve(self, path, method) -> tuple[Callable[..., Awaitable[Response]], dict]: for pattern, route in self.routes.items(): if m := pattern.fullmatch(path): - if method not in route: + if method not in route or method == "websocket": return self.error_routes[405], {} return route[method], { k: urllib.parse.unquote(v) for k, v in m.groupdict().items() - } + } # ty:ignore[invalid-return-type] return self.error_routes[404], {} def methods(self, path) -> set[str]: for pattern, route in self.routes.items(): if pattern.fullmatch(path): - return set(route.keys()) + return set(route.keys() - {"websocket"}) return set() async def handle_client( @@ -203,6 +208,12 @@ class App: if not request_line: return + if not request_line.decode(encoding="utf-8").startswith( + ("GET", "POST", "PUT", "HEAD", "DELETE", "PATCH", "OPTIONS") + ): + # Probably WebSocket + pass + # Parse request line parts = request_line.decode(encoding="utf-8").strip().split() if len(parts) < 3: @@ -251,6 +262,17 @@ class App: else: route, kwargs = self.resolve(path, method) + if ( + method == "GET" + and "connection" in headers + and headers.get("connection") == "Upgrade" + and "upgrade" in headers + ): + # Upgrade + return Response( + status=426, headers=["Connection: close"], body=b"" + ).render(self) + response: Response = await route( request=Request( method=method, path=path, headers=headers, body=body