100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
from collections import defaultdict
|
|
from datetime import date, datetime
|
|
from typing import Iterable
|
|
|
|
import requests
|
|
|
|
from regiojet_search.cache import cache
|
|
from regiojet_search.db import fetch_results, save_results
|
|
from regiojet_search.models import City, Country, Fare, Result
|
|
from regiojet_search.settings import LOCATIONS_URL, SEARCH_URL
|
|
|
|
|
|
def fetch_cities() -> list[City]:
    """Return every city exposed by the locations endpoint, using the cache first.

    The cache is consulted under the key ``"cities"``; on a hit each cached
    entry is expanded into a ``City`` via keyword arguments.  On a miss the
    list is downloaded from ``LOCATIONS_URL``, flattened across all countries,
    written back to the cache and returned.

    Raises:
        requests.RequestException: if the request times out, fails at the
            network level, or the server answers with an error status.
    """
    cache_key = "cities"
    cached_cities = cache.get(cache_key)

    if cached_cities is not None:
        return [City(**cached_city) for cached_city in cached_cities]

    # Without a timeout ``requests`` waits forever on a stalled connection.
    response = requests.get(LOCATIONS_URL, timeout=30)
    # Fail loudly on 4xx/5xx instead of trying to parse an error payload.
    response.raise_for_status()

    cities: list[City] = []
    for raw_country in response.json():
        country = Country(code=raw_country["code"], name=raw_country["country"])
        cities.extend(
            City(id=raw_city["id"], name=raw_city["name"], country=country)
            for raw_city in raw_country["cities"]
        )

    cache.set(cache_key, cities)

    return cities
|
|
|
|
|
|
def search(from_city: City, to_city: City, departure_date: date, currency: str) -> list[Result]:
    """Find connections between two cities on a given date.

    Lookup order:

    1. the cache, under the key ``journey:<slug>``;
    2. the database, via ``fetch_results(slug)`` (a hit is also written back
       to the cache, but not re-saved to the database);
    3. the remote search endpoint at ``SEARCH_URL``.

    Results fetched from the remote endpoint are all persisted, but only those
    departing on ``departure_date`` are returned.

    Args:
        from_city: City to depart from; its ``id`` and ``name`` are used.
        to_city: City to arrive at; its ``id`` and ``name`` are used.
        departure_date: Requested day of departure.
        currency: Currency code; sent upper-cased in the ``X-Currency`` header.

    Returns:
        The matching results.

    Raises:
        requests.RequestException: if the remote request times out, fails at
            the network level, or the server answers with an error status.
    """
    slug = Result.slugify(from_city.name, to_city.name, departure_date, currency)
    cache_key = f"journey:{slug}"
    cached_results = cache.get(cache_key)

    if cached_results is not None:
        return [Result(**cached_result) for cached_result in cached_results]

    if results_from_db := fetch_results(slug):
        # Already stored in the database, so only refresh the cache.
        persist_results(results_from_db, save_to_database=False)
        return results_from_db

    response = requests.get(
        SEARCH_URL,
        headers={"X-Currency": currency.upper()},
        params={
            "tariffs": "REGULAR",
            "fromLocationType": "CITY",
            "toLocationType": "CITY",
            "fromLocationId": from_city.id,
            "toLocationId": to_city.id,
            "departureDate": departure_date.isoformat(),
        },
        # Without a timeout ``requests`` waits forever on a stalled connection.
        timeout=30,
    )
    # Surface HTTP errors directly rather than as a confusing KeyError on
    # the missing "routes" key of an error payload.
    response.raise_for_status()

    results = list(parse_results(response.json()["routes"], from_city, to_city, currency=currency))

    # Persist everything that was returned, then narrow to the requested day.
    persist_results(results)

    return discard_different_departure_dates(results, departure_date)
|
|
|
|
|
|
def parse_results(raw_results: list[dict], from_city: City, to_city: City, currency: str) -> Iterable[Result]:
    """Lazily convert raw route dictionaries into ``Result`` objects.

    Args:
        raw_results: Route dictionaries; each one is read for the keys
            ``departureTime``, ``arrivalTime``, ``priceFrom``,
            ``vehicleTypes``, ``departureStationId``, ``arrivalStationId``,
            ``freeSeatsCount`` and ``vehicleStandardKey``.
        from_city: Supplies the ``origin`` name of every result.
        to_city: Supplies the ``destination`` name of every result.
        currency: Currency attached to every fare.

    Yields:
        One ``Result`` per raw route, in input order.
    """
    for route in raw_results:
        departs_at = datetime.fromisoformat(route["departureTime"])
        arrives_at = datetime.fromisoformat(route["arrivalTime"])
        origin_name = from_city.name
        destination_name = to_city.name
        price = Fare(amount=route["priceFrom"], currency=currency)
        # Only the first listed vehicle type is kept, lower-cased.
        vehicle_kind = route["vehicleTypes"][0].lower()

        yield Result(
            departure=departs_at,
            arrival=arrives_at,
            origin=origin_name,
            destination=destination_name,
            fare=price,
            type=vehicle_kind,
            source_id=route["departureStationId"],
            destination_id=route["arrivalStationId"],
            free_seats=route["freeSeatsCount"],
            carrier=route["vehicleStandardKey"],
        )
|
|
|
|
|
|
def discard_different_departure_dates(results: list[Result], departure: date) -> list[Result]:
    """Keep only the results whose departure falls on the given calendar day.

    Args:
        results: Candidate results; each exposes a ``departure`` datetime.
        departure: The day to keep.

    Returns:
        A new list with the matching results in their original order.
    """
    same_day: list[Result] = []
    for candidate in results:
        if candidate.departure.date() == departure:
            same_day.append(candidate)
    return same_day
|
|
|
|
|
|
def persist_results(results: list[Result], save_to_redis: bool = True, save_to_database: bool = True) -> None:
    """Store results in the database and/or the cache, grouped by slug.

    Results are bucketed by their ``slug`` attribute and every bucket is
    saved as one unit.

    Args:
        results: Results to store; may span several slugs.
        save_to_redis: When true, write each bucket to the cache.
        save_to_database: When true, pass each bucket to ``save_results``.
    """
    buckets: dict[str, list[Result]] = defaultdict(list)

    for result in results:
        buckets[result.slug].append(result)

    for slug, subset_of_results in buckets.items():
        if save_to_database:
            save_results(subset_of_results)
        if save_to_redis:
            # ``search`` looks journeys up under "journey:<slug>"; writing
            # under the bare slug meant cached entries were never found.
            # NOTE(review): assumes ``Result.slug`` carries no prefix itself.
            cache.set(f"journey:{slug}", subset_of_results)
|